Skip to content

Commit f29fdf8

Browse files
rename and add dataclass
1 parent 3d3d048 commit f29fdf8

10 files changed

Lines changed: 1825 additions & 3656 deletions

File tree

gw_eccentricity/posterior/examples/post_process_for_bilby.ipynb

Lines changed: 0 additions & 3425 deletions
This file was deleted.
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
from dataclasses import dataclass
2+
import logging
3+
import pandas as pd
4+
from ..gw_eccentricity import measure_eccentricity
5+
6+
logger = logging.getLogger(__name__)
7+
8+
@dataclass
9+
class PostProcessResult:
10+
sample_index: int
11+
status: str
12+
egw: float | None
13+
lgw: float | None
14+
error_message: str | None = None
15+
16+
@dataclass
17+
class PostProcessResults:
18+
results: list[PostProcessResult]
19+
20+
def to_dataframe(self) -> pd.DataFrame:
21+
"""Convert the list of PostProcessResult to a pandas DataFrame."""
22+
return pd.DataFrame([result.__dict__ for result in self.results])
23+
24+
def success_only(self) -> list[PostProcessResult]:
25+
"""Filter successful results."""
26+
return [r for r in self.results if r.status == "success"]
27+
28+
def get_summary(self) -> dict:
29+
"""Get summary statistics."""
30+
successful = self.success_only()
31+
total = len(self.results)
32+
return {
33+
'total_samples': total,
34+
'success_percentage': (len(successful) / total) * 100,
35+
'egw': [r.egw for r in successful],
36+
'lgw': [r.lgw for r in successful]
37+
}
38+
39+
@dataclass
40+
class FrefBoundsResult:
41+
sample_index: int
42+
status: str
43+
fref_min: float | None
44+
fref_max: float | None
45+
error_message: str | None = None
46+
47+
48+
@dataclass
49+
class FrefBoundsResults:
50+
results: list[FrefBoundsResult]
51+
52+
def to_dataframe(self) -> pd.DataFrame:
53+
"""Convert the list of FrefBoundsResult to a pandas DataFrame."""
54+
return pd.DataFrame([result.__dict__ for result in self.results])
55+
56+
def success_only(self) -> list[FrefBoundsResult]:
57+
"""Filter successful results."""
58+
return [r for r in self.results if r.status == "success"]
59+
60+
def get_summary(self) -> dict:
61+
"""Get summary statistics."""
62+
successful = self.success_only()
63+
total = len(self.results)
64+
return {
65+
'total_samples': total,
66+
'success_percentage': (len(successful) / total) * 100,
67+
'fref_min': [r.fref_min for r in successful],
68+
'fref_max': [r.fref_max for r in successful]
69+
}
70+
71+
def get_minmax_fref(self) -> tuple[float, float] | None:
72+
"""Get the min and max fref across all successful samples.
73+
74+
This provides the common range of fref values where eccentricity can be
75+
measured for all the successful samples.
76+
"""
77+
summary = self.get_summary()
78+
if summary['success_percentage'] == 0:
79+
raise Exception("No successful samples to determine fref bounds.")
80+
return max(summary['fref_min']), min(summary['fref_max'])
81+
82+
83+
def get_data_dict(
84+
params: dict,
85+
data_dict_generator: callable,
86+
extra_kwargs: dict | None = None
87+
) -> dict:
88+
"""Get data_dict for given params in the posterior.
89+
90+
Parameters
91+
----------
92+
params : dict
93+
Dictionary containing the parameters for the sample.
94+
95+
data_dict_generator : function
96+
data_dict is generated using function call as below::
97+
98+
data_dict = data_dict_generator(params, extra_kwargs)
99+
100+
extra_kwargs : dict, optional
101+
Extra kwargs passed to ``data_dict_generator``.
102+
103+
Returns
104+
-------
105+
data_dict : dict
106+
Dictionary of waveform modes data compatible with
107+
``gw_eccentricity.measure_eccentricity``.
108+
"""
109+
if extra_kwargs is None:
110+
extra_kwargs = {}
111+
data_dict = data_dict_generator(
112+
params, extra_kwargs)
113+
if not isinstance(data_dict, dict):
114+
raise TypeError(
115+
f"The data_dict generator `{data_dict_generator}` should "
116+
f"return a dict and not a {type(data_dict)}")
117+
return data_dict
118+
119+
120+
def get_fref_bounds_for_sample(
121+
sample_index: int,
122+
params: dict,
123+
data_dict_generator: callable,
124+
data_dict_generator_extra_kwargs: dict | None = None,
125+
method: str = "Amplitude",
126+
gw_eccentricity_kwargs: dict | None = None
127+
) -> FrefBoundsResult:
128+
"""Get the min and max allowed fref for a given sample.
129+
130+
Parameters
131+
----------
132+
sample_index : int
133+
Index of the sample in the posterior.
134+
params : dict
135+
Dictionary containing the parameters for the sample.
136+
data_dict_generator : function
137+
Function to generate the data dictionary for the sample.
138+
data_dict_generator_extra_kwargs : dict, optional
139+
Extra kwargs passed to ``data_dict_generator``.
140+
method : str, default="Amplitude"
141+
Method to use in ``gw_eccentricity.measure_eccentricity``.
142+
gw_eccentricity_kwargs : dict, optional
143+
Extra kwargs passed to ``gw_eccentricity.measure_eccentricity``.
144+
145+
Returns
146+
-------
147+
FrefBoundsResult
148+
with keys ``sample_index``, ``status``, ``fref_min``, ``fref_max``,
149+
and on failure ``error_message``.
150+
"""
151+
if gw_eccentricity_kwargs is None:
152+
gw_eccentricity_kwargs = {}
153+
try:
154+
data_dict = get_data_dict(params, data_dict_generator, data_dict_generator_extra_kwargs)
155+
res = measure_eccentricity(
156+
dataDict=data_dict,
157+
tref_in=data_dict["t"], # pass the full time array to get the fref bounds for the entire waveform
158+
method=method,
159+
**gw_eccentricity_kwargs)
160+
gw_obj = res["gwecc_object"]
161+
fref_bounds = gw_obj.get_fref_bounds()
162+
return FrefBoundsResult(
163+
sample_index=sample_index,
164+
status="success",
165+
fref_min=fref_bounds[0],
166+
fref_max=fref_bounds[1]
167+
)
168+
except Exception as e:
169+
logger.warning(f"Sample {params} failed to get fref bounds: {e}")
170+
return FrefBoundsResult(
171+
sample_index=sample_index,
172+
status="fail",
173+
fref_min=None,
174+
fref_max=None,
175+
error_message=str(e))
176+
177+
def postprocess_sample(
178+
sample_index: int,
179+
params: dict,
180+
fref: float,
181+
data_dict_generator: callable,
182+
data_dict_generator_extra_kwargs: dict | None = None,
183+
method: str = "Amplitude",
184+
gw_eccentricity_kwargs: dict | None = None) -> PostProcessResult:
185+
"""Measure eccentricity and mean anomaly from waveform modes for a sample.
186+
187+
A wrapper around ``gw_eccentricity.measure_eccentricity`` to measure
188+
eccentricity from the waveform modes for a sample with given ``params``.
189+
190+
Parameters
191+
----------
192+
sample_index : int
193+
Index of the sample in the posterior.
194+
params : dict
195+
Dictionary containing the parameters for the sample.
196+
fref : float
197+
Reference frequency where eccentricity is to be measured.
198+
data_dict_generator : function
199+
data_dict is generated using function call as below::
200+
data_dict = data_dict_generator(params, data_dict_generator_extra_kwargs)
201+
data_dict_generator_extra_kwargs : dict, optional
202+
Extra kwargs passed to ``data_dict_generator``.
203+
method : str, default="Amplitude"
204+
Method to use in ``gw_eccentricity.measure_eccentricity``.
205+
gw_eccentricity_kwargs : dict, optional
206+
Extra kwargs passed to ``gw_eccentricity.measure_eccentricity``.
207+
208+
Returns
209+
-------
210+
PostProcessResult
211+
with keys: ``status``, ``egw``, ``lgw``, and on failure
212+
``error_message``.
213+
"""
214+
try:
215+
data_dict = get_data_dict(
216+
params,
217+
data_dict_generator,
218+
data_dict_generator_extra_kwargs)
219+
res = measure_eccentricity(
220+
dataDict=data_dict,
221+
fref_in=fref,
222+
method=method,
223+
**(gw_eccentricity_kwargs or {}))
224+
return PostProcessResult(
225+
sample_index=sample_index,
226+
status="success",
227+
egw=res["eccentricity"],
228+
lgw=res["mean_anomaly"]
229+
)
230+
except Exception as e:
231+
logger.warning(f"Sample {params} failed: {e}")
232+
return PostProcessResult(
233+
sample_index=sample_index,
234+
status="fail",
235+
egw=None,
236+
lgw=None,
237+
error_message=str(e))

gw_eccentricity/posterior/examples/postprocess.sh renamed to gw_eccentricity/postprocess/examples/postprocess.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ CMD=(
99
--output-dir "/Users/arif/Desktop/"
1010
--output-format csv
1111
--save-every none
12-
--samples 0:1000
12+
--samples 0:100
1313
--fref 10
1414
--method AmplitudeFits
15-
--data-dict-generator "/Users/arif/gw_eccentricity/gw_eccentricity/posterior/examples/teob_backward_evolution.py:teob_data_dict_generator"
15+
--data-dict-generator "/Users/arif/gw_eccentricity/gw_eccentricity/postprocess/examples/teob_backward_evolution.py:teob_data_dict_generator"
1616
--data-dict-generator-extra-kwargs '{"backwards":"yes","ode_tmax":1}'
1717
--gw-eccentricity-kwargs '{}'
1818
)

0 commit comments

Comments
 (0)