Skip to content

Commit ee179e1

Browse files
committed
Improve handling of CoordRadii argument
1 parent 61a8b1a commit ee179e1

1 file changed

Lines changed: 39 additions & 8 deletions

File tree

scri/extrapolation.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numbers
12
import numpy as np
23
from numpy.polynomial.polynomial import polyfit
34

@@ -307,7 +308,7 @@ def read_finite_radius_waveform_rpxmb_or_rpdmb(filename,
307308
sxs_format = json_data.get("sxs_format","")
308309
if sxs_format in rpdmb_formats:
309310
read_rpdmb=True
310-
elif sxs_format in rpdmb_formats:
311+
elif sxs_format in rpxmb_formats:
311312
read_rpxmb=True
312313

313314
# Note that groupname begins with a '/' so
@@ -429,6 +430,13 @@ def read_finite_radius_data(ChMass=0.0,
429430
from re import compile as re_compile
430431
import scri
431432

433+
def extract_radius_string_from_waveform_name(name):
434+
# Expected names like "R0247.dir"
435+
m = re_compile(r"""R(?P<r>.*?)\.dir""").search(str(name))
436+
if not m:
437+
raise ValueError(f"Could not parse radius from waveform name {name!r}")
438+
return m.group("r")
439+
432440
YLMRegex = re_compile(mode_regex)
433441

434442
# If 'filename' is of the form "h5_file_name.h5/groupname" then we have an
@@ -456,23 +464,45 @@ def read_finite_radius_data(ChMass=0.0,
456464
WaveformNames.remove("VersionHist.ver")
457465
else:
458466
WaveformNames = list(f[groupname])
467+
AllWaveformNames = list(WaveformNames)
459468
if not CoordRadii:
460469
# If the list of Radii is empty, figure out what they are
461470
CoordRadii = [
462471
m.group("r") for Name in WaveformNames for m in [re_compile(r"""R(?P<r>.*?)\.dir""").search(Name)] if m
463472
]
464473
else:
465474
# Pare down the WaveformNames list appropriately
466-
if type(CoordRadii[0]) == int:
475+
if isinstance(CoordRadii[0], numbers.Integral):
467476
WaveformNames = [WaveformNames[i] for i in CoordRadii]
468-
CoordRadii = [
469-
m.group("r") for Name in CoordRadii
470-
for m in
471-
[ re_compile(r"""R(?P<r>.*?)\.dir""").search(Name)] if m ]
477+
# Convert selected waveform names into radius strings, keeping the same order
478+
CoordRadii = [extract_radius_string_from_waveform_name(Name) for Name in WaveformNames]
472479
else:
473480
WaveformNames = [
474481
Name for Name in WaveformNames for Radius in
475482
CoordRadii for m in [re_compile(Radius).search(Name)] if m]
483+
# Best-effort: If user gave explicit radii strings, keep them; otherwise derive
484+
# parsed radii corresponding to selected names for consistent sorting later.
485+
try:
486+
CoordRadii = [extract_radius_string_from_waveform_name(Name) for Name in WaveformNames]
487+
except Exception:
488+
# Fall back to whatever the user provided (for backwards compatibility)
489+
pass
490+
491+
if not WaveformNames:
492+
raise ValueError(
493+
"No waveform groups matched the requested CoordRadii. "
494+
f"Requested CoordRadii={CoordRadii}; available groups={AllWaveformNames}."
495+
)
496+
if len(WaveformNames) != len(CoordRadii):
497+
raise ValueError(
498+
"Mismatch between requested radii and matched waveform groups. "
499+
f"Requested CoordRadii={CoordRadii}; matched groups={WaveformNames}."
500+
)
501+
try:
502+
[float(r) for r in CoordRadii]
503+
except Exception as e:
504+
raise ValueError(f"Could not convert CoordRadii entries to float: CoordRadii={CoordRadii}") from e
505+
476506
NWaveforms = len(WaveformNames)
477507

478508
# Check input data for NRAR format
@@ -727,7 +757,7 @@ def extrapolate(**kwargs):
727757
D['DataFile'] = {DataFile}
728758
D['ChMass'] = {ChMass}
729759
D['HorizonsFile'] = {HorizonsFile}
730-
D['CoordRadii'] = {CoordRadii}
760+
D['CoordRadii'] = {CoordRadiiKwarg}
731761
D['ExtrapolationOrders'] = {ExtrapolationOrders}
732762
D['UseOmega'] = {UseOmega}
733763
D['OutputFrame'] = {OutputFrame}
@@ -826,7 +856,8 @@ def extrapolate(**kwargs):
826856

827857
# Append the relevant information to the history
828858
ExtrapolatedWaveforms[i]._append_history(str(InputArguments))
829-
ExtrapolatedWaveforms[i].extrapolate_coord_radii = CoordRadiiKwarg
859+
# Record the actual radii used for extrapolation so downstream code has the normalized set
860+
ExtrapolatedWaveforms[i].extrapolate_coord_radii = list(CoordRadii)
830861

831862
# Output the data
832863
if OutputDirectory:

0 commit comments

Comments
 (0)