Skip to content

Commit 839bd03

Browse files
Merge pull request #257 from vijayvarma392/add_tests
add more tests
2 parents 785f245 + 9f781b2 commit 839bd03

12 files changed

Lines changed: 705 additions & 37 deletions

.github/workflows/test.yml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
os: [macos-latest, ubuntu-latest]
22-
python-version: ['3.9']
22+
python-version: ['3.9', '3.10']
2323

2424
steps:
2525
- name: Check out repository code
@@ -98,6 +98,18 @@ jobs:
9898
run: |
9999
pip install .
100100
101+
- name: Download SXS NR data for precessing tests
102+
if: matrix.python-version != '3.9'
103+
shell: bash -l {0}
104+
env:
105+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
106+
run: |
107+
python - <<'EOF'
108+
from gw_eccentricity.load_data import download_sxs_waveform
109+
download_sxs_waveform("SXS:BBH:2859", 4, "/tmp/gwecc_test_data",
110+
catalog_tag="v3.0.0")
111+
EOF
112+
101113
- name: Run test suite
102114
shell: bash -l {0}
103115
run: |

gw_eccentricity/eccDefinition.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,6 +2379,21 @@ def measure_ecc(self):
23792379
self.tmax = min(self.t_pericenters[-1], self.t_apocenters[-1])
23802380
self.tmin = max(self.t_pericenters[0], self.t_apocenters[0])
23812381

2382+
# t_for_checks depends only on tmin/tmax (already set above).
2383+
# Build it early so get_omega_gw_extrema_interpolant can use it when
2384+
# converting fref_in → tref_in via mean_of_extrema_interpolants.
2385+
self.t_for_checks = self.dataDict["t"][
2386+
np.logical_and(self.dataDict["t"] >= self.tmin,
2387+
self.dataDict["t"] <= self.tmax)]
2388+
2389+
# Build omega_gw extrema interpolants before the fref_in conversion so
2390+
# that compute_mean_of_extrema_interpolants is available as an
2391+
# averaging method in compute_tref_in_and_fref_out_from_fref_in.
2392+
self.omega_gw_pericenters_interp \
2393+
= self.get_omega_gw_extrema_interpolant("pericenters")
2394+
self.omega_gw_apocenters_interp \
2395+
= self.get_omega_gw_extrema_interpolant("apocenters")
2396+
23822397
if self.domain == "frequency":
23832398
# get the tref_in and fref_out from fref_in
23842399
self.tref_in, self.fref_out \
@@ -2387,10 +2402,6 @@ def measure_ecc(self):
23872402
self.tref_out = self.tref_in[
23882403
np.logical_and(self.tref_in <= self.tmax,
23892404
self.tref_in >= self.tmin)]
2390-
# set time for checks and diagnostics
2391-
self.t_for_checks = self.dataDict["t"][
2392-
np.logical_and(self.dataDict["t"] >= self.tmin,
2393-
self.dataDict["t"] <= self.tmax)]
23942405

23952406
# Sanity checks
23962407
# check that fref_out and tref_out are of the same length
@@ -2427,11 +2438,6 @@ def measure_ecc(self):
24272438
or self.tref_out[-1] > self.t_pericenters[-1]:
24282439
raise Exception("Reference time must be within two pericenters.")
24292440

2430-
# Build omega_gw extrema interpolants
2431-
self.omega_gw_pericenters_interp \
2432-
= self.get_omega_gw_extrema_interpolant("pericenters")
2433-
self.omega_gw_apocenters_interp \
2434-
= self.get_omega_gw_extrema_interpolant("apocenters")
24352441
# compute eccentricity at self.tref_out
24362442
self.eccentricity = self.compute_eccentricity(self.tref_out)
24372443
# Compute mean anomaly at tref_out

gw_eccentricity/load_data.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,87 @@ def load_lvcnr_waveform(**kwargs):
743743
return return_dict
744744

745745

746+
def download_sxs_waveform(sxs_id, lev, data_dir, extrap_order=2,
747+
catalog_tag=""):
748+
"""Download SXS catalog waveform files to a local directory.
749+
750+
Downloads the files required by ``load_waveform(origin="SXSCatalog")``:
751+
752+
- ``Strain_N{extrap_order}.h5``
753+
- ``Strain_N{extrap_order}.json``
754+
- ``metadata.json``
755+
- ``Horizons.h5``
756+
757+
Parameters
758+
----------
759+
sxs_id : str
760+
SXS simulation identifier, e.g. ``"SXS:BBH:3726"``.
761+
lev : int
762+
Resolution level, e.g. ``3``.
763+
data_dir : str
764+
Root directory. Files are placed inside
765+
``<data_dir>/<sxs_id_safe>/Lev<lev>/``, where ``sxs_id_safe``
766+
is ``sxs_id`` with ``":"`` replaced by ``"_"``
767+
(e.g. ``"SXS_BBH_3726"``).
768+
extrap_order : int, optional
769+
Waveform extrapolation order. Default is 2, which selects the
770+
``Strain_N2.h5`` / ``Strain_N2.json`` file pair.
771+
catalog_tag : str, optional
772+
Git tag of the SXS catalog release to use, e.g. ``"v3.0.0"``.
773+
When provided the catalog is fetched directly from that tagged
774+
release, bypassing the GitHub API call that discovers the latest
775+
release. This avoids GitHub API rate-limit errors in CI.
776+
Defaults to ``""`` which fetches the latest release.
777+
778+
Returns
779+
-------
780+
str
781+
Full path to the directory that now contains the downloaded
782+
files. Pass this as ``data_dir`` to
783+
``load_waveform(origin="SXSCatalog", data_dir=...)``.
784+
785+
Notes
786+
-----
787+
Files that already exist on disk are skipped, that is not re-downloaded.
788+
The SXS simulations catalogue is loaded via the ``sxs`` Python package to
789+
retrieve direct download URLs for each file.
790+
"""
791+
from sxs.utilities import download_file
792+
793+
# Build target directory
794+
sxs_id_safe = sxs_id.replace(":", "_")
795+
target_dir = os.path.join(
796+
os.path.expanduser(data_dir), sxs_id_safe, f"Lev{lev}")
797+
os.makedirs(target_dir, exist_ok=True)
798+
799+
# Load the simulations catalogue to resolve file URLs.
800+
# Pass `tag` when provided to skip the GitHub API call that discovers
801+
# the latest release (avoids rate-limit errors in CI environments).
802+
simulations = sxs.Simulations.load(download=True, tag=catalog_tag)
803+
if sxs_id not in simulations:
804+
raise ValueError(
805+
f"'{sxs_id}' not found in the SXS simulations catalogue. "
806+
"Check the SXS ID and that the catalogue is up to date.")
807+
808+
sim_files = simulations[sxs_id].get("files", {})
809+
810+
filenames = [
811+
f"Strain_N{extrap_order}.h5",
812+
f"Strain_N{extrap_order}.json",
813+
"metadata.json",
814+
"Horizons.h5",
815+
]
816+
817+
for filename in filenames:
818+
dest = os.path.join(target_dir, filename)
819+
if os.path.exists(dest):
820+
continue # already present, skip
821+
url = sim_files[f"Lev{lev}:{filename}"]["link"]
822+
download_file(url, dest, progress=False)
823+
824+
return target_dir
825+
826+
746827
def load_sxs_catalogformat(**kwargs):
747828
"""Load modes from sxs waveform files in sxs catalog format.
748829

gw_eccentricity/rational_fit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def _eval_arnoldi_basis(x, H, degree, v0_norm):
104104
# Build higher degree basis functions using the stored Hessenberg
105105
# matrix H
106106
for k in range(degree):
107-
v = V[:, k] * x - V[:, :k+1] @ H[:k+1, k]
107+
with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
108+
v = V[:, k] * x - V[:, :k+1] @ H[:k+1, k]
108109
if H[k + 1, k] < 1e-14:
109110
break
110111
V[:, k + 1] = v / H[k + 1, k]
@@ -405,8 +406,9 @@ def predict(self, x_new):
405406

406407
# Compute the numerator and denominator at x_new using the
407408
# fitted coefficients
408-
p = P @ self.a
409-
q = Q @ self.b
409+
with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
410+
p = P @ self.a
411+
q = Q @ self.b
410412

411413
# Return the rational function values at x_new
412414
r = p / q

pytest.ini

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
[pytest]
1+
[pytest]
2+
filterwarnings =
3+
ignore::DeprecationWarning:pyseobnr
4+
ignore::DeprecationWarning:scipy
5+
ignore::DeprecationWarning:qnm

test/conftest.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Shared test helpers and fixtures for the gw_eccentricity test suite."""
2+
import numpy as np
3+
from pyseobnr.generate_waveform import generate_modes_opt
4+
5+
6+
def get_seob_datadict(q=5.0, chi1z=0.4, chi2z=0.3,
7+
omega_start=0.02, ecc=0.1, mean_ano=1.7,
8+
include_zero_ecc=False):
9+
"""Return a dataDict built from a SEOBNRv5EHM waveform.
10+
11+
Parameters
12+
----------
13+
q : mass ratio
14+
chi1z, chi2z : aligned spin components
15+
omega_start : initial orbital frequency in geometric units
16+
ecc : initial eccentricity
17+
mean_ano : initial mean anomaly
18+
include_zero_ecc : if True, also add t_zeroecc/hlm_zeroecc keys built
19+
from a zero-eccentricity waveform starting slightly
20+
earlier (omega_start * 0.9), required by Residual* methods
21+
"""
22+
t, modes = generate_modes_opt(
23+
q, chi1z, chi2z, omega_start,
24+
eccentricity=ecc, rel_anomaly=mean_ano,
25+
approximant="SEOBNRv5EHM")
26+
hlm = {tuple(int(x) for x in k.split(",")): v for k, v in modes.items()}
27+
dataDict = {"t": t, "hlm": hlm}
28+
29+
if include_zero_ecc:
30+
t_z, modes_z = generate_modes_opt(
31+
q, chi1z, chi2z, omega_start * 0.9,
32+
eccentricity=0, rel_anomaly=0,
33+
approximant="SEOBNRv5EHM")
34+
hlm_z = {tuple(int(x) for x in k.split(",")): v for k, v in modes_z.items()}
35+
dataDict.update({"t_zeroecc": t_z, "hlm_zeroecc": hlm_z})
36+
37+
return dataDict

test/test_mks_vs_dimless_units.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from lal import MTSUN_SI
66

77

8+
# rational_fit introduces ~2x more floating-point noise than spline due to its
9+
# iterative QR algorithm, pushing unit-conversion differences to ~2e-7.
10+
# rtol=1e-6 comfortably covers both methods while still catching real unit bugs.
11+
# (numpy assert_allclose default is 1e-7)
12+
UNIT_CONSISTENCY_RTOL = 1e-6
13+
14+
815
def test_mks_vs_dimless_units():
916
""" Tests that the measure_eccentricity interface is working for both
1017
MKS and dimensionless units.
@@ -38,9 +45,7 @@ def test_mks_vs_dimless_units():
3845
"D": 1})
3946
dataDictMKS = load_data.load_waveform(**lal_kwargs)
4047

41-
# use different omega_gw_extrema_interpolation methods
42-
# TODO: Add rational_fit to the list
43-
omega_gw_extrema_interpolation_methods = ["spline"]
48+
omega_gw_extrema_interpolation_methods = ["spline", "rational_fit"]
4449

4550
# List of all available methods
4651
available_methods = gw_eccentricity.get_available_methods()
@@ -74,19 +79,22 @@ def test_mks_vs_dimless_units():
7479
np.testing.assert_allclose(
7580
[tref_out],
7681
[tref_out_MKS * sec_to_dimless],
82+
rtol=UNIT_CONSISTENCY_RTOL,
7783
err_msg=("tref_out at a single dimensionless and MKS"
7884
" time are inconsistent.\n"
7985
"x = Dimensionless, y = MKS converted to dimless"))
8086
# Check if the measured ecc an mean ano are the same from the two units
8187
np.testing.assert_allclose(
8288
[ecc_ref],
8389
[ecc_ref_MKS],
90+
rtol=UNIT_CONSISTENCY_RTOL,
8491
err_msg=("Eccentricity at a single dimensionless and MKS"
8592
" time gives different results.\n"
8693
"x = Dimensionless, y = MKS"))
8794
np.testing.assert_allclose(
8895
[meanano_ref],
8996
[meanano_ref_MKS],
97+
rtol=UNIT_CONSISTENCY_RTOL,
9098
err_msg=("Mean anomaly at a single dimensionless and MKS"
9199
" time gives different results.\n"
92100
"x = Dimensionless, y = MKS"))
@@ -115,13 +123,15 @@ def test_mks_vs_dimless_units():
115123
np.testing.assert_allclose(
116124
[tref_out],
117125
[tref_out_MKS * sec_to_dimless],
126+
rtol=UNIT_CONSISTENCY_RTOL,
118127
err_msg=("tref_out array for dimensionless and MKS"
119128
" tref_in are inconsistent.\n"
120129
"x = Dimensionless, y = MKS converted to dimless"))
121130
# Check if the measured ecc an mean ano are the same from the two units
122131
np.testing.assert_allclose(
123132
ecc_ref,
124133
ecc_ref_MKS,
134+
rtol=UNIT_CONSISTENCY_RTOL,
125135
err_msg=("Eccentricity at dimensionless and MKS array of"
126136
" times are different\n."
127137
"x = Dimensionless, y = MKS"))
@@ -130,6 +140,7 @@ def test_mks_vs_dimless_units():
130140
np.testing.assert_allclose(
131141
np.unwrap(meanano_ref),
132142
np.unwrap(meanano_ref_MKS),
143+
rtol=UNIT_CONSISTENCY_RTOL,
133144
err_msg=("Mean anomaly at dimensionless and MKS array of"
134145
" times are different.\n"
135146
"x = Dimensionless, y = MKS"))
@@ -160,19 +171,22 @@ def test_mks_vs_dimless_units():
160171
np.testing.assert_allclose(
161172
[fref_out],
162173
[fref_out_MKS / sec_to_dimless],
174+
rtol=UNIT_CONSISTENCY_RTOL,
163175
err_msg=("fref_out for a single dimensionless and MKS"
164176
" fref_in are inconsistent.\n"
165177
"x = Dimensionless, y = MKS converted to dimless"))
166178
# Check if the measured ecc an mean ano are the same from the two units
167179
np.testing.assert_allclose(
168180
[ecc_ref],
169181
[ecc_ref_MKS],
182+
rtol=UNIT_CONSISTENCY_RTOL,
170183
err_msg=("Eccentricity at a single dimensionless and MKS"
171184
" frequency gives different results.\n"
172185
"x = Dimensionless, y = MKS"))
173186
np.testing.assert_allclose(
174187
[meanano_ref],
175188
[meanano_ref_MKS],
189+
rtol=UNIT_CONSISTENCY_RTOL,
176190
err_msg=("Mean anomaly at a single dimensionless and MKS"
177191
" frequency gives different results.\n"
178192
"x = Dimensionless, y = MKS"))
@@ -203,19 +217,22 @@ def test_mks_vs_dimless_units():
203217
np.testing.assert_allclose(
204218
[fref_out],
205219
[fref_out_MKS / sec_to_dimless],
220+
rtol=UNIT_CONSISTENCY_RTOL,
206221
err_msg=("fref_out for an array of dimensionless and MKS"
207222
" fref_in are inconsistent.\n"
208223
"x = Dimensionless, y = MKS converted to dimless"))
209224
# Check if the measured ecc an mean ano are the same from the two units
210225
np.testing.assert_allclose(
211226
ecc_ref,
212227
ecc_ref_MKS,
228+
rtol=UNIT_CONSISTENCY_RTOL,
213229
err_msg=("Eccentricity at dimensionless and MKS array of"
214230
" frequencies are different.\n"
215231
"x = Dimensionless, y = MKS"))
216232
np.testing.assert_allclose(
217233
np.unwrap(meanano_ref),
218234
np.unwrap(meanano_ref_MKS),
235+
rtol=UNIT_CONSISTENCY_RTOL,
219236
err_msg=("Mean anomaly at dimensionless and MKS array of"
220237
" frequencies are different.\n"
221238
"x = Dimensionless, y = MKS"))

0 commit comments

Comments
 (0)