|
| 1 | +import time |
| 2 | +import jax |
| 3 | +import jax.numpy as jnp |
| 4 | +from jimgw.core.jim import Jim |
| 5 | +from jimgw.core.prior import ( |
| 6 | + CombinePrior, |
| 7 | + UniformPrior, |
| 8 | + CosinePrior, |
| 9 | + SinePrior, |
| 10 | + PowerLawPrior, |
| 11 | + UniformSpherePrior, |
| 12 | +) |
| 13 | +from jimgw.core.single_event.detector import get_H1, get_L1 |
| 14 | +from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD |
| 15 | +from jimgw.core.single_event.data import Data |
| 16 | +from jimgw.core.single_event.waveform import JaxNRSur7dq4 |
| 17 | +from jimgw.core.transforms import BoundToUnbound |
| 18 | +from jimgw.core.single_event.transforms import ( |
| 19 | + SkyFrameToDetectorFrameSkyPositionTransform, |
| 20 | + SphereSpinToCartesianSpinTransform, |
| 21 | + MassRatioToSymmetricMassRatioTransform, |
| 22 | + DistanceToSNRWeightedDistanceTransform, |
| 23 | + GeocentricArrivalTimeToDetectorArrivalTimeTransform, |
| 24 | + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, |
| 25 | +) |
| 26 | + |
| 27 | +jax.config.update("jax_enable_x64", True) |
| 28 | + |
| 29 | +########################################### |
| 30 | +########## First we grab data ############# |
| 31 | +########################################### |
| 32 | + |
| 33 | +total_time_start = time.time() |
| 34 | + |
| 35 | +# first, fetch a 4s segment centered on GW150914 |
| 36 | +# for the analysis |
| 37 | +gps = 1126259462.4 |
| 38 | +start = gps - 2 |
| 39 | +end = gps + 2 |
| 40 | + |
| 41 | +# fetch 4096s of data to estimate the PSD (to be |
| 42 | +# careful we should avoid the on-source segment, |
| 43 | +# but we don't do this in this example) |
| 44 | +psd_start = gps - 2048 |
| 45 | +psd_end = gps + 2048 |
| 46 | + |
| 47 | +# define frequency integration bounds for the likelihood |
| 48 | +# we set fmax to 87.5% of the Nyquist frequency to avoid |
| 49 | +# data corrupted by the GWOSC antialiasing filter |
| 50 | +# (Note that Data.from_gwosc will pull data sampled at |
| 51 | +# 4096 Hz by default) |
| 52 | +fmin = 20.0 |
| 53 | +fmax = 1024 |
| 54 | + |
| 55 | +# initialize detectors |
| 56 | +ifos = [get_H1(), get_L1()] |
| 57 | + |
| 58 | +for ifo in ifos: |
| 59 | + # set analysis data |
| 60 | + data = Data.from_gwosc(ifo.name, start, end) |
| 61 | + ifo.set_data(data) |
| 62 | + |
| 63 | + # set PSD (Welch estimate) |
| 64 | + psd_data = Data.from_gwosc(ifo.name, psd_start, psd_end) |
| 65 | + # set an NFFT corresponding to the analysis segment duration |
| 66 | + psd_fftlength = data.duration * data.sampling_frequency |
| 67 | + ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength)) |
| 68 | + |
| 69 | +########################################### |
| 70 | +########## Set up waveform ################ |
| 71 | +########################################### |
| 72 | + |
| 73 | +# initialize waveform |
| 74 | +waveform = JaxNRSur7dq4(segment_length=ifos[0].data.duration, sampling_rate=int(ifos[0].data.sampling_frequency)) |
| 75 | + |
| 76 | +########################################### |
| 77 | +########## Set up priors ################## |
| 78 | +########################################### |
| 79 | + |
| 80 | +prior = [] |
| 81 | + |
| 82 | +# Mass prior |
| 83 | +M_c_min, M_c_max = 10.0, 80.0 |
| 84 | +q_min, q_max = 0.125, 1.0 |
| 85 | +Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) |
| 86 | +q_prior = UniformPrior(q_min, q_max, parameter_names=["q"]) |
| 87 | + |
| 88 | +prior = prior + [Mc_prior, q_prior] |
| 89 | + |
| 90 | +# Spin prior |
| 91 | +s1_prior = UniformSpherePrior(parameter_names=["s1"]) |
| 92 | +s2_prior = UniformSpherePrior(parameter_names=["s2"]) |
| 93 | +iota_prior = SinePrior(parameter_names=["iota"]) |
| 94 | + |
| 95 | +prior = prior + [ |
| 96 | + s1_prior, |
| 97 | + s2_prior, |
| 98 | + iota_prior, |
| 99 | +] |
| 100 | + |
| 101 | +# Extrinsic prior |
| 102 | +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) |
| 103 | +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) |
| 104 | +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) |
| 105 | +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) |
| 106 | +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) |
| 107 | +dec_prior = CosinePrior(parameter_names=["dec"]) |
| 108 | + |
| 109 | +prior = prior + [ |
| 110 | + dL_prior, |
| 111 | + t_c_prior, |
| 112 | + phase_c_prior, |
| 113 | + psi_prior, |
| 114 | + ra_prior, |
| 115 | + dec_prior, |
| 116 | +] |
| 117 | + |
| 118 | +prior = CombinePrior(prior) |
| 119 | + |
| 120 | +# Defining Transforms |
| 121 | + |
| 122 | +sample_transforms = [ |
| 123 | + DistanceToSNRWeightedDistanceTransform( |
| 124 | + gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax |
| 125 | + ), |
| 126 | + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), |
| 127 | + GeocentricArrivalTimeToDetectorArrivalTimeTransform( |
| 128 | + tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0] |
| 129 | + ), |
| 130 | + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), |
| 131 | + BoundToUnbound( |
| 132 | + name_mapping=(["M_c"], ["M_c_unbounded"]), |
| 133 | + original_lower_bound=M_c_min, |
| 134 | + original_upper_bound=M_c_max, |
| 135 | + ), |
| 136 | + BoundToUnbound( |
| 137 | + name_mapping=(["q"], ["q_unbounded"]), |
| 138 | + original_lower_bound=q_min, |
| 139 | + original_upper_bound=q_max, |
| 140 | + ), |
| 141 | + BoundToUnbound( |
| 142 | + name_mapping=(["s1_phi"], ["s1_phi_unbounded"]), |
| 143 | + original_lower_bound=0.0, |
| 144 | + original_upper_bound=2 * jnp.pi, |
| 145 | + ), |
| 146 | + BoundToUnbound( |
| 147 | + name_mapping=(["s2_phi"], ["s2_phi_unbounded"]), |
| 148 | + original_lower_bound=0.0, |
| 149 | + original_upper_bound=2 * jnp.pi, |
| 150 | + ), |
| 151 | + BoundToUnbound( |
| 152 | + name_mapping=(["iota"], ["iota_unbounded"]), |
| 153 | + original_lower_bound=0.0, |
| 154 | + original_upper_bound=jnp.pi, |
| 155 | + ), |
| 156 | + BoundToUnbound( |
| 157 | + name_mapping=(["s1_theta"], ["s1_theta_unbounded"]), |
| 158 | + original_lower_bound=0.0, |
| 159 | + original_upper_bound=jnp.pi, |
| 160 | + ), |
| 161 | + BoundToUnbound( |
| 162 | + name_mapping=(["s2_theta"], ["s2_theta_unbounded"]), |
| 163 | + original_lower_bound=0.0, |
| 164 | + original_upper_bound=jnp.pi, |
| 165 | + ), |
| 166 | + BoundToUnbound( |
| 167 | + name_mapping=(["s1_mag"], ["s1_mag_unbounded"]), |
| 168 | + original_lower_bound=0.0, |
| 169 | + original_upper_bound=0.99, |
| 170 | + ), |
| 171 | + BoundToUnbound( |
| 172 | + name_mapping=(["s2_mag"], ["s2_mag_unbounded"]), |
| 173 | + original_lower_bound=0.0, |
| 174 | + original_upper_bound=0.99, |
| 175 | + ), |
| 176 | + BoundToUnbound( |
| 177 | + name_mapping=(["phase_det"], ["phase_det_unbounded"]), |
| 178 | + original_lower_bound=0.0, |
| 179 | + original_upper_bound=2 * jnp.pi, |
| 180 | + ), |
| 181 | + BoundToUnbound( |
| 182 | + name_mapping=(["psi"], ["psi_unbounded"]), |
| 183 | + original_lower_bound=0.0, |
| 184 | + original_upper_bound=jnp.pi, |
| 185 | + ), |
| 186 | + BoundToUnbound( |
| 187 | + name_mapping=(["zenith"], ["zenith_unbounded"]), |
| 188 | + original_lower_bound=0.0, |
| 189 | + original_upper_bound=jnp.pi, |
| 190 | + ), |
| 191 | + BoundToUnbound( |
| 192 | + name_mapping=(["azimuth"], ["azimuth_unbounded"]), |
| 193 | + original_lower_bound=0.0, |
| 194 | + original_upper_bound=2 * jnp.pi, |
| 195 | + ), |
| 196 | +] |
| 197 | + |
| 198 | +likelihood_transforms = [ |
| 199 | + MassRatioToSymmetricMassRatioTransform, |
| 200 | + SphereSpinToCartesianSpinTransform("s1"), |
| 201 | + SphereSpinToCartesianSpinTransform("s2"), |
| 202 | +] |
| 203 | + |
| 204 | + |
| 205 | +likelihood = BaseTransientLikelihoodFD( |
| 206 | + ifos, |
| 207 | + waveform=waveform, |
| 208 | + trigger_time=gps, |
| 209 | + f_min=fmin, |
| 210 | + f_max=fmax, |
| 211 | +) |
| 212 | + |
| 213 | +jim = Jim( |
| 214 | + likelihood, |
| 215 | + prior, |
| 216 | + sample_transforms=sample_transforms, |
| 217 | + likelihood_transforms=likelihood_transforms, |
| 218 | + n_chains=500, |
| 219 | + n_local_steps=100, |
| 220 | + n_global_steps=1000, |
| 221 | + n_training_loops=20, |
| 222 | + n_production_loops=10, |
| 223 | + n_epochs=20, |
| 224 | + mala_step_size=2e-3, |
| 225 | + rq_spline_hidden_units=[128, 128], |
| 226 | + rq_spline_n_bins=10, |
| 227 | + rq_spline_n_layers=8, |
| 228 | + learning_rate=1e-3, |
| 229 | + batch_size=10000, |
| 230 | + n_max_examples=30000, |
| 231 | + n_NFproposal_batch_size=100, |
| 232 | + local_thinning=1, |
| 233 | + global_thinning=100, |
| 234 | + history_window=200, |
| 235 | + n_temperatures=0, |
| 236 | + max_temperature=20.0, |
| 237 | + n_tempered_steps=10, |
| 238 | + verbose=True, |
| 239 | +) |
| 240 | +# |
| 241 | +jim.sample() |
| 242 | + |
| 243 | +print("Done!") |
| 244 | + |
| 245 | +logprob = jim.sampler.resources["log_prob_production"].data |
| 246 | +print(jnp.mean(logprob)) |
| 247 | + |
| 248 | +chains = jim.get_samples() |
| 249 | + |
| 250 | +import numpy as np |
| 251 | +import corner |
| 252 | + |
| 253 | +fig = corner.corner( |
| 254 | + np.stack([chains[key] for key in jim.prior.parameter_names]).T[::10] |
| 255 | +) |
| 256 | +fig.savefig("test") |
0 commit comments