Skip to content

Commit 68baa6b

Browse files
committed
Update waveform.py
1 parent caa5f05 commit 68baa6b

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/jimgw/core/single_event/waveform.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ def __repr__(self):
203203
class JaxNRSurHyb3dq8(Waveform):
204204
_waveform: FDWaveform.NRSurHyb3dq8_FD.value
205205

206-
def __init__(self, segment_length: float, sampling_rate: int = 4096, alpha_window: float = 0.1):
207-
self._waveform = FDWaveform.NRSurHyb3dq8_FD.value(segment_length=segment_length, sampling_rate=sampling_rate, alpha_window=alpha_window)
206+
def __init__(self, target_frequency: Float[Array, " n_sample"], segment_length: float, sampling_rate: int = 4096, alpha_window: float = 0.1):
207+
self._waveform = FDWaveform.NRSurHyb3dq8_FD.value(target_frequency, segment_length=segment_length, sampling_rate=sampling_rate, alpha_window=alpha_window)
208208

209209
def __call__(
210210
self, frequency: Float[Array, " n_dim"], params: dict[str, Float]
@@ -231,14 +231,14 @@ def __call__(
231231
return output
232232

233233
def __repr__(self):
234-
return f"JaxNRSurHyb3dq8(segment_length={self._waveform.segment_length}, sampling_rate={self._waveform.sampling_rate})"
234+
return f"JaxNRSurHyb3dq8(segment_length={self._waveform.surrogate.segment_length}, sampling_rate={self._waveform.surrogate.sampling_rate})"
235235

236236

237237
class JaxNRSur7dq4(Waveform):
238238
_waveform: FDWaveform.NRSur7dq4_FD.value
239239

240-
def __init__(self, segment_length: float, sampling_rate: int = 4096, alpha_window: float = 0.1):
241-
self._waveform = FDWaveform.NRSur7dq4_FD.value(segment_length=segment_length, sampling_rate=sampling_rate, alpha_window=alpha_window)
240+
def __init__(self, target_frequency: Float[Array, " n_sample"], segment_length: float, sampling_rate: int = 4096, alpha_window: float = 0.1):
241+
self._waveform = FDWaveform.NRSur7dq4_FD.value(target_frequency, segment_length=segment_length, sampling_rate=sampling_rate, alpha_window=alpha_window)
242242

243243
def __call__(
244244
self, frequency: Float[Array, " n_dim"], params: dict[str, Float]
@@ -270,7 +270,7 @@ def __call__(
270270
return output
271271

272272
def __repr__(self):
273-
return f"JaxNRSur7dq4(segment_length={self._waveform.segment_length}, sampling_rate={self._waveform.sampling_rate})"
273+
return f"JaxNRSur7dq4(segment_length={self._waveform.surrogate.segment_length}, sampling_rate={self._waveform.surrogate.sampling_rate})"
274274

275275

276276

0 commit comments

Comments
 (0)