Skip to content

Commit 849bb71

Browse files
committed
flake
1 parent 5ea4072 commit 849bb71

6 files changed

Lines changed: 41 additions & 40 deletions

File tree

enterprise_extensions/blocks.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,10 @@ def white_noise_block(
170170
else:
171171
if name is None:
172172
ec = white_signals.EcorrKernelNoise(log10_ecorr=ecorr,
173-
selection=backend_ng, name=name)
173+
selection=selection_ecorr)
174174
else:
175175
ec = white_signals.EcorrKernelNoise(log10_ecorr=ecorr,
176-
selection=backend_ng)
176+
selection=selection_ecorr, name=name)
177177
# combine signals
178178
if inc_ecorr:
179179
s = efeq + ec
@@ -670,7 +670,7 @@ def dm_noise_block(
670670
elif prior == "gaussian":
671671
log10_A_dm = parameter.Normal(logmin, logmax)
672672
elif not vary:
673-
log10_A_dm = parameter.Constant()
673+
log10_A_dm = parameter.Constant()
674674
else:
675675
if prior == "uniform":
676676
log10_A_dm = parameter.LinearExp(-20, -10)
@@ -783,10 +783,10 @@ def dm_noise_block(
783783
log10_rho_dm = parameter.Uniform(-10, -4, size=components)
784784
else:
785785
log10_rho_dm = parameter.Uniform(-9, -4, size=components)
786-
if not vary: # here just overwrite the prior to constant if not varying it
786+
if not vary: # here just overwrite the prior to constant if not varying it
787787
log10_rho_dm = parameter.Constant()
788788

789-
dm_prior = gpp.free_spectrum(log10_rho=log10_rho)
789+
dm_prior = gpp.free_spectrum(log10_rho=log10_rho_dm)
790790

791791
if tndm:
792792
dm_basis = utils.createfourierdesignmatrix_dm_tn(
@@ -893,7 +893,7 @@ def dm_noise_block(
893893
log10_sigma_ridge = parameter.Constant()
894894

895895
dm_basis = gpk.linear_interp_basis_dm(dt=dt * const.day)
896-
dm_prior = gpk.dmx_ridge_prior(log10_sigma=log10_sigma)
896+
dm_prior = gpk.dmx_ridge_prior(log10_sigma=log10_sigma_ridge)
897897

898898
if select is None:
899899
dmgp = gp_signals.BasisGP(
@@ -1110,7 +1110,7 @@ def chromatic_noise_block(
11101110
if vary:
11111111
log10_B = parameter.Uniform(-10, -4)
11121112
else:
1113-
log10B = parameter.Constant()
1113+
log10_B = parameter.Constant()
11141114
chm_prior = gpp.flat_powerlaw(
11151115
log10_A=log10_A, gamma=gamma, log10_B=log10_B
11161116
)
@@ -1130,7 +1130,7 @@ def chromatic_noise_block(
11301130
log10_rho = parameter.Uniform(-9, -4, size=components)
11311131
if vary:
11321132
log10_rho = parameter.Constant()
1133-
1133+
11341134
chm_prior = gpp.free_spectrum(log10_rho=log10_rho)
11351135

11361136
if tndm:

enterprise_extensions/chromatic/chromatic.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
from enterprise import constants as const
55
from enterprise.signals import deterministic_signals, parameter, signal_base, gp_bases
6-
from .solar_wind import solar_wind
76

87
__all__ = [
98
"chrom_exp_decay",
@@ -429,11 +428,11 @@ def dmx_signal(dmx_data, name="dmx_signal", vary=True):
429428
for dmx_id in sorted(dmx_data):
430429
dmx_data_tmp = dmx_data[dmx_id]
431430
dmx.update(
432-
{
433-
dmx_id: parameter.Normal(
434-
mu=dmx_data_tmp["DMX_VAL"], sigma=dmx_data_tmp["DMX_ERR"]
435-
)
436-
}
431+
{
432+
dmx_id: parameter.Normal(
433+
mu=dmx_data_tmp["DMX_VAL"], sigma=dmx_data_tmp["DMX_ERR"]
434+
)
435+
}
437436
)
438437
else:
439438
for dmx_id in sorted(dmx_data):
@@ -446,7 +445,6 @@ def dmx_signal(dmx_data, name="dmx_signal", vary=True):
446445

447446

448447
def dm_annual_signal(idx=2, tmin=None, tmax=None, name="dm_s1yr", vary=True):
449-
def dm_annual_signal(idx=2, name="dm_s1yr"):
450448
"""
451449
Returns chromatic annual signal (i.e. TOA advance):
452450
@@ -487,7 +485,7 @@ def construct_chromatic_cached_parts(
487485
fmax=None,
488486
modes=None,
489487
fref=1400,
490-
):
488+
):
491489
"""
492490
Using this function alongside `createfourierdesignmatrix_chromatic_with_additional_caching()`
493491
enables caching of the achromatic portion of the chromatic Fourier designmatrix as well as caching

enterprise_extensions/chromatic/solar_wind.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ def solar_wind_block(
248248
n_earth = ACE_SWEPAM_Parameter(size=n_earth_bins.size-1)("n_earth")
249249
elif n_earth is None and (isinstance(n_earth_bins, list) or
250250
isinstance(n_earth_bins, np.ndarray)) and not ACE_prior:
251-
n_earth = parameter.Uniform(0,30,size=n_earth_bins.size-1)("n_earth")
251+
n_earth = parameter.Uniform(0, 30, size=n_earth_bins.size-1)("n_earth")
252252
else:
253-
pass # set n_earth to the provided value(s) below
253+
pass # set n_earth to the provided value(s) below
254254

255255
deter_sw = solar_wind(n_earth=n_earth, n_earth_bins=n_earth_bins, t_init=t_init, t_final=t_final)
256256
mean_sw = deterministic_signals.Deterministic(deter_sw, name=det_name)
@@ -262,13 +262,13 @@ def solar_wind_block(
262262
sw_basis = createfourierdesignmatrix_solar_dm(modes=modes)
263263
nmodes = len(modes)
264264
elif Tspan is not None:
265-
sw_basis = createfourierdesignmatrix_solar_dm(nmodes=nmodes,
266-
Tspan=Tspan)
265+
sw_basis = createfourierdesignmatrix_solar_dm(nmodes=nmodes,
266+
Tspan=Tspan)
267267
if swgp_prior == "powerlaw":
268268
if vary_swgp:
269269
# sometimes amplitudes larger than 1 break the likelihood
270-
log10_A_sw = parameter.Uniform(-12, 0) # sometimes positive amplitudes break this
271-
gamma_sw = parameter.Uniform(-6, 5) # priors from susurla et al. 2024
270+
log10_A_sw = parameter.Uniform(-12, 0) # sometimes positive amplitudes break this
271+
gamma_sw = parameter.Uniform(-6, 5) # priors from susurla et al. 2024
272272
else:
273273
log10_A_sw = parameter.Constant()
274274
gamma_sw = parameter.Constant()
@@ -292,9 +292,9 @@ def solar_wind_block(
292292
if swgp_prior == "periodic":
293293
# Periodic GP kernel for DM
294294
if vary_swgp:
295-
log10_sigma = parameter.Uniform(-10, -4) # units are log10(seconds)
296-
log10_ell = parameter.Uniform(1, 4) # units are log10(days)
297-
log10_p = parameter.Uniform(-4, 1.5) # units are log10(years)
295+
log10_sigma = parameter.Uniform(-10, -4) # units are log10(seconds)
296+
log10_ell = parameter.Uniform(1, 4) # units are log10(days)
297+
log10_p = parameter.Uniform(-4, 1.5) # units are log10(years)
298298
log10_gam_p = parameter.Uniform(-3, 2)
299299
else:
300300
log10_sigma = parameter.Constant()
@@ -303,9 +303,9 @@ def solar_wind_block(
303303
log10_gam_p = parameter.Constant()
304304

305305
sw_prior = gpk.periodic_kernel(log10_sigma=log10_sigma,
306-
log10_ell=log10_ell,
307-
log10_gam_p=log10_gam_p,
308-
log10_p=log10_p)
306+
log10_ell=log10_ell,
307+
log10_gam_p=log10_gam_p,
308+
log10_p=log10_p)
309309
elif swgp_prior == "sq_exp":
310310
# squared-exponential GP kernel for DM
311311
if vary_swgp:
@@ -338,7 +338,7 @@ def solar_wind_block(
338338
sw_prior = gpk.sw_dm_wn_prior(log10_sigma_ne=log10_sigma_ne)
339339
else:
340340
raise ValueError("Invalid triangular-basis SWGP prior specified.")
341-
341+
342342
else:
343343
raise ValueError("Invalid SWGP basis specified.")
344344

enterprise_extensions/models.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from enterprise_extensions.chromatic.solar_wind import solar_wind_block
3232
from enterprise_extensions.timing import timing_block
3333

34-
# from enterprise.signals.signal_base import LookupLikelihood
3534

3635
def model_singlepsr_noise(
3736
psr,
@@ -45,6 +44,7 @@ def model_singlepsr_noise(
4544
tm_svd=False,
4645
tm_norm=True,
4746
white_vary=True,
47+
gp_ecorr=False,
4848
components=30,
4949
upper_limit=False,
5050
is_wideband=False,
@@ -53,6 +53,7 @@ def model_singlepsr_noise(
5353
dmjump_var=False,
5454
gamma_val=None,
5555
dm_var=False,
56+
vary_dm=True,
5657
dm_type="gp",
5758
dmgp_kernel="diag",
5859
dm_psd="powerlaw",
@@ -64,6 +65,7 @@ def model_singlepsr_noise(
6465
dm_df=200,
6566
dm_Nfreqs=100,
6667
chrom_gp=False,
68+
vary_chrom=True,
6769
chrom_gp_kernel="nondiag",
6870
chrom_psd="powerlaw",
6971
chrom_idx=4,
@@ -72,6 +74,7 @@ def model_singlepsr_noise(
7274
chrom_dt=15,
7375
chrom_df=200,
7476
chrom_Nfreqs=100,
77+
vary_dm_dips=True,
7578
dm_expdip=False,
7679
dmexp_sign="negative",
7780
dm_expdip_idx=2,
@@ -138,6 +141,7 @@ def model_singlepsr_noise(
138141
is_wideband
139142
:param gamma_val: red noise spectral index to fix
140143
:param dm_var: whether to explicitly model DM-variations
144+
:param vary_dm: whether to vary the DM model GP hyperparams or use constant values
141145
:param dm_type: gaussian process ('gp') or dmx ('dmx')
142146
:param dmgp_kernel: diagonal in frequency or non-diagonal
143147
:param dm_psd: power-spectral density of DM variations
@@ -149,6 +153,7 @@ def model_singlepsr_noise(
149153
:param dm_df: frequency-scale for DM linear interpolation basis (MHz)
150154
:param dm_Nfreqs: Number of Fourier modes to use for the dm_gp model.
151155
:param chrom_gp: include general chromatic noise
156+
:param vary_chrom: whether to vary the chromatic GP hyperparams or use constant values
152157
:param chrom_gp_kernel: GP kernel type to use in chrom ['diag','nondiag']
153158
:param chrom_psd: power-spectral density of chromatic noise
154159
['powerlaw','tprocess','free_spectrum']
@@ -161,6 +166,7 @@ def model_singlepsr_noise(
161166
:param chrom_dt: time-scale for chromatic linear interpolation basis (days)
162167
:param chrom_df: frequency-scale for chromatic linear interpolation basis (MHz)
163168
:param chrom_Nfreqs: Number of Fourier modes to use for the chromatic GP.
169+
:param vary_dm_dips: whether to vary the DM dip parameters or keep them fixed
164170
:param dm_expdip: inclue a DM exponential dip
165171
:param dmexp_sign: set the sign parameter for dip
166172
:param dm_expdip_idx: chromatic index of exponential dip
@@ -424,7 +430,7 @@ def model_singlepsr_noise(
424430
sign=dm_dual_cusp_sign,
425431
symmetric=dm_dual_cusp_sym,
426432
name=dual_cusp_name_base + str(dd),
427-
vary=vary_dm)
433+
vary=vary_dm,
428434
)
429435
if dm_sw_deter:
430436
Tspan = psr.toas.max() - psr.toas.min()
@@ -445,8 +451,7 @@ def model_singlepsr_noise(
445451
include_quadratic=chrom_quad,
446452
coefficients=coefficients,
447453
Tspan=Tspan,
448-
vary=vary_chrom,
449-
idx_prior_upper_bound=chrom_gp_idx_prior_upper_bound)
454+
vary=vary_chrom,)
450455
if extra_sigs is not None:
451456
s += extra_sigs
452457

enterprise_extensions/sampler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,11 @@ def __init__(
366366
# extend empirical_distr here:
367367
print(f'extending {key}\'s empirical distributions to priors...\n')
368368
self.empirical_distr[key] = extend_emp_dists(
369-
pta,
370-
self.empirical_distr[key],
371-
npoints=100_000,
372-
save_ext_dists=save_ext_dists,
373-
outdir=outdir
369+
pta,
370+
self.empirical_distr[key],
371+
npoints=100_000,
372+
save_ext_dists=save_ext_dists,
373+
outdir=outdir
374374
)
375375
else:
376376
mask = []

tests/test_models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import json
77
import logging
88
import os
9-
10-
import pickle
119
import numpy as np
1210

1311
import pytest

0 commit comments

Comments
 (0)