Skip to content

Commit caa5f05

Browse files
committed
Add GW150914_NRSur7dq4.py example script
1 parent 373a7a7 commit caa5f05

1 file changed

Lines changed: 256 additions & 0 deletions

File tree

example/GW150914_NRSur7dq4.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import time
2+
import jax
3+
import jax.numpy as jnp
4+
from jimgw.core.jim import Jim
5+
from jimgw.core.prior import (
6+
CombinePrior,
7+
UniformPrior,
8+
CosinePrior,
9+
SinePrior,
10+
PowerLawPrior,
11+
UniformSpherePrior,
12+
)
13+
from jimgw.core.single_event.detector import get_H1, get_L1
14+
from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD
15+
from jimgw.core.single_event.data import Data
16+
from jimgw.core.single_event.waveform import JaxNRSur7dq4
17+
from jimgw.core.transforms import BoundToUnbound
18+
from jimgw.core.single_event.transforms import (
19+
SkyFrameToDetectorFrameSkyPositionTransform,
20+
SphereSpinToCartesianSpinTransform,
21+
MassRatioToSymmetricMassRatioTransform,
22+
DistanceToSNRWeightedDistanceTransform,
23+
GeocentricArrivalTimeToDetectorArrivalTimeTransform,
24+
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform,
25+
)
26+
27+
jax.config.update("jax_enable_x64", True)
28+
29+
###########################################
30+
########## First we grab data #############
31+
###########################################
32+
33+
total_time_start = time.time()
34+
35+
# first, fetch a 4s segment centered on GW150914
36+
# for the analysis
37+
gps = 1126259462.4
38+
start = gps - 2
39+
end = gps + 2
40+
41+
# fetch 4096s of data to estimate the PSD (to be
42+
# careful we should avoid the on-source segment,
43+
# but we don't do this in this example)
44+
psd_start = gps - 2048
45+
psd_end = gps + 2048
46+
47+
# define frequency integration bounds for the likelihood
48+
# we set fmax to 87.5% of the Nyquist frequency to avoid
49+
# data corrupted by the GWOSC antialiasing filter
50+
# (Note that Data.from_gwosc will pull data sampled at
51+
# 4096 Hz by default)
52+
fmin = 20.0
53+
fmax = 1024
54+
55+
# initialize detectors
56+
ifos = [get_H1(), get_L1()]
57+
58+
for ifo in ifos:
59+
# set analysis data
60+
data = Data.from_gwosc(ifo.name, start, end)
61+
ifo.set_data(data)
62+
63+
# set PSD (Welch estimate)
64+
psd_data = Data.from_gwosc(ifo.name, psd_start, psd_end)
65+
# set an NFFT corresponding to the analysis segment duration
66+
psd_fftlength = data.duration * data.sampling_frequency
67+
ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength))
68+
69+
###########################################
70+
########## Set up waveform ################
71+
###########################################
72+
73+
# initialize waveform
74+
waveform = JaxNRSur7dq4(segment_length=ifos[0].data.duration, sampling_rate=int(ifos[0].data.sampling_frequency))
75+
76+
###########################################
77+
########## Set up priors ##################
78+
###########################################
79+
80+
prior = []
81+
82+
# Mass prior
83+
M_c_min, M_c_max = 10.0, 80.0
84+
q_min, q_max = 0.125, 1.0
85+
Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"])
86+
q_prior = UniformPrior(q_min, q_max, parameter_names=["q"])
87+
88+
prior = prior + [Mc_prior, q_prior]
89+
90+
# Spin prior
91+
s1_prior = UniformSpherePrior(parameter_names=["s1"])
92+
s2_prior = UniformSpherePrior(parameter_names=["s2"])
93+
iota_prior = SinePrior(parameter_names=["iota"])
94+
95+
prior = prior + [
96+
s1_prior,
97+
s2_prior,
98+
iota_prior,
99+
]
100+
101+
# Extrinsic prior
102+
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
103+
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
104+
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
105+
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
106+
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
107+
dec_prior = CosinePrior(parameter_names=["dec"])
108+
109+
prior = prior + [
110+
dL_prior,
111+
t_c_prior,
112+
phase_c_prior,
113+
psi_prior,
114+
ra_prior,
115+
dec_prior,
116+
]
117+
118+
prior = CombinePrior(prior)
119+
120+
# Defining Transforms
121+
122+
sample_transforms = [
123+
DistanceToSNRWeightedDistanceTransform(
124+
gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax
125+
),
126+
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
127+
GeocentricArrivalTimeToDetectorArrivalTimeTransform(
128+
tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]
129+
),
130+
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
131+
BoundToUnbound(
132+
name_mapping=(["M_c"], ["M_c_unbounded"]),
133+
original_lower_bound=M_c_min,
134+
original_upper_bound=M_c_max,
135+
),
136+
BoundToUnbound(
137+
name_mapping=(["q"], ["q_unbounded"]),
138+
original_lower_bound=q_min,
139+
original_upper_bound=q_max,
140+
),
141+
BoundToUnbound(
142+
name_mapping=(["s1_phi"], ["s1_phi_unbounded"]),
143+
original_lower_bound=0.0,
144+
original_upper_bound=2 * jnp.pi,
145+
),
146+
BoundToUnbound(
147+
name_mapping=(["s2_phi"], ["s2_phi_unbounded"]),
148+
original_lower_bound=0.0,
149+
original_upper_bound=2 * jnp.pi,
150+
),
151+
BoundToUnbound(
152+
name_mapping=(["iota"], ["iota_unbounded"]),
153+
original_lower_bound=0.0,
154+
original_upper_bound=jnp.pi,
155+
),
156+
BoundToUnbound(
157+
name_mapping=(["s1_theta"], ["s1_theta_unbounded"]),
158+
original_lower_bound=0.0,
159+
original_upper_bound=jnp.pi,
160+
),
161+
BoundToUnbound(
162+
name_mapping=(["s2_theta"], ["s2_theta_unbounded"]),
163+
original_lower_bound=0.0,
164+
original_upper_bound=jnp.pi,
165+
),
166+
BoundToUnbound(
167+
name_mapping=(["s1_mag"], ["s1_mag_unbounded"]),
168+
original_lower_bound=0.0,
169+
original_upper_bound=0.99,
170+
),
171+
BoundToUnbound(
172+
name_mapping=(["s2_mag"], ["s2_mag_unbounded"]),
173+
original_lower_bound=0.0,
174+
original_upper_bound=0.99,
175+
),
176+
BoundToUnbound(
177+
name_mapping=(["phase_det"], ["phase_det_unbounded"]),
178+
original_lower_bound=0.0,
179+
original_upper_bound=2 * jnp.pi,
180+
),
181+
BoundToUnbound(
182+
name_mapping=(["psi"], ["psi_unbounded"]),
183+
original_lower_bound=0.0,
184+
original_upper_bound=jnp.pi,
185+
),
186+
BoundToUnbound(
187+
name_mapping=(["zenith"], ["zenith_unbounded"]),
188+
original_lower_bound=0.0,
189+
original_upper_bound=jnp.pi,
190+
),
191+
BoundToUnbound(
192+
name_mapping=(["azimuth"], ["azimuth_unbounded"]),
193+
original_lower_bound=0.0,
194+
original_upper_bound=2 * jnp.pi,
195+
),
196+
]
197+
198+
likelihood_transforms = [
199+
MassRatioToSymmetricMassRatioTransform,
200+
SphereSpinToCartesianSpinTransform("s1"),
201+
SphereSpinToCartesianSpinTransform("s2"),
202+
]
203+
204+
205+
likelihood = BaseTransientLikelihoodFD(
206+
ifos,
207+
waveform=waveform,
208+
trigger_time=gps,
209+
f_min=fmin,
210+
f_max=fmax,
211+
)
212+
213+
jim = Jim(
214+
likelihood,
215+
prior,
216+
sample_transforms=sample_transforms,
217+
likelihood_transforms=likelihood_transforms,
218+
n_chains=500,
219+
n_local_steps=100,
220+
n_global_steps=1000,
221+
n_training_loops=20,
222+
n_production_loops=10,
223+
n_epochs=20,
224+
mala_step_size=2e-3,
225+
rq_spline_hidden_units=[128, 128],
226+
rq_spline_n_bins=10,
227+
rq_spline_n_layers=8,
228+
learning_rate=1e-3,
229+
batch_size=10000,
230+
n_max_examples=30000,
231+
n_NFproposal_batch_size=100,
232+
local_thinning=1,
233+
global_thinning=100,
234+
history_window=200,
235+
n_temperatures=0,
236+
max_temperature=20.0,
237+
n_tempered_steps=10,
238+
verbose=True,
239+
)
240+
#
241+
jim.sample()
242+
243+
print("Done!")
244+
245+
logprob = jim.sampler.resources["log_prob_production"].data
246+
print(jnp.mean(logprob))
247+
248+
chains = jim.get_samples()
249+
250+
import numpy as np
251+
import corner
252+
253+
fig = corner.corner(
254+
np.stack([chains[key] for key in jim.prior.parameter_names]).T[::10]
255+
)
256+
fig.savefig("test")

0 commit comments

Comments
 (0)