1+ from dataclasses import dataclass
2+ import logging
3+ import pandas as pd
4+ from ..gw_eccentricity import measure_eccentricity
5+
6+ logger = logging .getLogger (__name__ )
7+
8+ @dataclass
9+ class PostProcessResult :
10+ sample_index : int
11+ status : str
12+ egw : float | None
13+ lgw : float | None
14+ error_message : str | None = None
15+
16+ @dataclass
17+ class PostProcessResults :
18+ results : list [PostProcessResult ]
19+
20+ def to_dataframe (self ) -> pd .DataFrame :
21+ """Convert the list of PostProcessResult to a pandas DataFrame."""
22+ return pd .DataFrame ([result .__dict__ for result in self .results ])
23+
24+ def success_only (self ) -> list [PostProcessResult ]:
25+ """Filter successful results."""
26+ return [r for r in self .results if r .status == "success" ]
27+
28+ def get_summary (self ) -> dict :
29+ """Get summary statistics."""
30+ successful = self .success_only ()
31+ total = len (self .results )
32+ return {
33+ 'total_samples' : total ,
34+ 'success_percentage' : (len (successful ) / total ) * 100 ,
35+ 'egw' : [r .egw for r in successful ],
36+ 'lgw' : [r .lgw for r in successful ]
37+ }
38+
39+ @dataclass
40+ class FrefBoundsResult :
41+ sample_index : int
42+ status : str
43+ fref_min : float | None
44+ fref_max : float | None
45+ error_message : str | None = None
46+
47+
48+ @dataclass
49+ class FrefBoundsResults :
50+ results : list [FrefBoundsResult ]
51+
52+ def to_dataframe (self ) -> pd .DataFrame :
53+ """Convert the list of FrefBoundsResult to a pandas DataFrame."""
54+ return pd .DataFrame ([result .__dict__ for result in self .results ])
55+
56+ def success_only (self ) -> list [FrefBoundsResult ]:
57+ """Filter successful results."""
58+ return [r for r in self .results if r .status == "success" ]
59+
60+ def get_summary (self ) -> dict :
61+ """Get summary statistics."""
62+ successful = self .success_only ()
63+ total = len (self .results )
64+ return {
65+ 'total_samples' : total ,
66+ 'success_percentage' : (len (successful ) / total ) * 100 ,
67+ 'fref_min' : [r .fref_min for r in successful ],
68+ 'fref_max' : [r .fref_max for r in successful ]
69+ }
70+
71+ def get_minmax_fref (self ) -> tuple [float , float ] | None :
72+ """Get the min and max fref across all successful samples.
73+
74+ This provides the common range of fref values where eccentricity can be
75+ measured for all the successful samples.
76+ """
77+ summary = self .get_summary ()
78+ if summary ['success_percentage' ] == 0 :
79+ raise Exception ("No successful samples to determine fref bounds." )
80+ return max (summary ['fref_min' ]), min (summary ['fref_max' ])
81+
82+
83+ def get_data_dict (
84+ params : dict ,
85+ data_dict_generator : callable ,
86+ extra_kwargs : dict | None = None
87+ ) -> dict :
88+ """Get data_dict for given params in the posterior.
89+
90+ Parameters
91+ ----------
92+ params : dict
93+ Dictionary containing the parameters for the sample.
94+
95+ data_dict_generator : function
96+ data_dict is generated using function call as below::
97+
98+ data_dict = data_dict_generator(params, extra_kwargs)
99+
100+ extra_kwargs : dict, optional
101+ Extra kwargs passed to ``data_dict_generator``.
102+
103+ Returns
104+ -------
105+ data_dict : dict
106+ Dictionary of waveform modes data compatible with
107+ ``gw_eccentricity.measure_eccentricity``.
108+ """
109+ if extra_kwargs is None :
110+ extra_kwargs = {}
111+ data_dict = data_dict_generator (
112+ params , extra_kwargs )
113+ if not isinstance (data_dict , dict ):
114+ raise TypeError (
115+ f"The data_dict generator `{ data_dict_generator } ` should "
116+ f"return a dict and not a { type (data_dict )} " )
117+ return data_dict
118+
119+
120+ def get_fref_bounds_for_sample (
121+ sample_index : int ,
122+ params : dict ,
123+ data_dict_generator : callable ,
124+ data_dict_generator_extra_kwargs : dict | None = None ,
125+ method : str = "Amplitude" ,
126+ gw_eccentricity_kwargs : dict | None = None
127+ ) -> FrefBoundsResult :
128+ """Get the min and max allowed fref for a given sample.
129+
130+ Parameters
131+ ----------
132+ sample_index : int
133+ Index of the sample in the posterior.
134+ params : dict
135+ Dictionary containing the parameters for the sample.
136+ data_dict_generator : function
137+ Function to generate the data dictionary for the sample.
138+ data_dict_generator_extra_kwargs : dict, optional
139+ Extra kwargs passed to ``data_dict_generator``.
140+ method : str, default="Amplitude"
141+ Method to use in ``gw_eccentricity.measure_eccentricity``.
142+ gw_eccentricity_kwargs : dict, optional
143+ Extra kwargs passed to ``gw_eccentricity.measure_eccentricity``.
144+
145+ Returns
146+ -------
147+ FrefBoundsResult
148+ with keys ``sample_index``, ``status``, ``fref_min``, ``fref_max``,
149+ and on failure ``error_message``.
150+ """
151+ if gw_eccentricity_kwargs is None :
152+ gw_eccentricity_kwargs = {}
153+ try :
154+ data_dict = get_data_dict (params , data_dict_generator , data_dict_generator_extra_kwargs )
155+ res = measure_eccentricity (
156+ dataDict = data_dict ,
157+ tref_in = data_dict ["t" ], # pass the full time array to get the fref bounds for the entire waveform
158+ method = method ,
159+ ** gw_eccentricity_kwargs )
160+ gw_obj = res ["gwecc_object" ]
161+ fref_bounds = gw_obj .get_fref_bounds ()
162+ return FrefBoundsResult (
163+ sample_index = sample_index ,
164+ status = "success" ,
165+ fref_min = fref_bounds [0 ],
166+ fref_max = fref_bounds [1 ]
167+ )
168+ except Exception as e :
169+ logger .warning (f"Sample { params } failed to get fref bounds: { e } " )
170+ return FrefBoundsResult (
171+ sample_index = sample_index ,
172+ status = "fail" ,
173+ fref_min = None ,
174+ fref_max = None ,
175+ error_message = str (e ))
176+
177+ def postprocess_sample (
178+ sample_index : int ,
179+ params : dict ,
180+ fref : float ,
181+ data_dict_generator : callable ,
182+ data_dict_generator_extra_kwargs : dict | None = None ,
183+ method : str = "Amplitude" ,
184+ gw_eccentricity_kwargs : dict | None = None ) -> PostProcessResult :
185+ """Measure eccentricity and mean anomaly from waveform modes for a sample.
186+
187+ A wrapper around ``gw_eccentricity.measure_eccentricity`` to measure
188+ eccentricity from the waveform modes for a sample with given ``params``.
189+
190+ Parameters
191+ ----------
192+ sample_index : int
193+ Index of the sample in the posterior.
194+ params : dict
195+ Dictionary containing the parameters for the sample.
196+ fref : float
197+ Reference frequency where eccentricity is to be measured.
198+ data_dict_generator : function
199+ data_dict is generated using function call as below::
200+ data_dict = data_dict_generator(params, data_dict_generator_extra_kwargs)
201+ data_dict_generator_extra_kwargs : dict, optional
202+ Extra kwargs passed to ``data_dict_generator``.
203+ method : str, default="Amplitude"
204+ Method to use in ``gw_eccentricity.measure_eccentricity``.
205+ gw_eccentricity_kwargs : dict, optional
206+ Extra kwargs passed to ``gw_eccentricity.measure_eccentricity``.
207+
208+ Returns
209+ -------
210+ PostProcessResult
211+ with keys: ``status``, ``egw``, ``lgw``, and on failure
212+ ``error_message``.
213+ """
214+ try :
215+ data_dict = get_data_dict (
216+ params ,
217+ data_dict_generator ,
218+ data_dict_generator_extra_kwargs )
219+ res = measure_eccentricity (
220+ dataDict = data_dict ,
221+ fref_in = fref ,
222+ method = method ,
223+ ** (gw_eccentricity_kwargs or {}))
224+ return PostProcessResult (
225+ sample_index = sample_index ,
226+ status = "success" ,
227+ egw = res ["eccentricity" ],
228+ lgw = res ["mean_anomaly" ]
229+ )
230+ except Exception as e :
231+ logger .warning (f"Sample { params } failed: { e } " )
232+ return PostProcessResult (
233+ sample_index = sample_index ,
234+ status = "fail" ,
235+ egw = None ,
236+ lgw = None ,
237+ error_message = str (e ))
0 commit comments