@@ -203,8 +203,8 @@ def __repr__(self):
203203class JaxNRSurHyb3dq8 (Waveform ):
204204 _waveform : FDWaveform .NRSurHyb3dq8_FD .value
205205
206- def __init__ (self , segment_length : float , sampling_rate : int = 4096 , alpha_window : float = 0.1 ):
207- self ._waveform = FDWaveform .NRSurHyb3dq8_FD .value (segment_length = segment_length , sampling_rate = sampling_rate , alpha_window = alpha_window )
206+ def __init__ (self , target_frequency : Float [ Array , " n_sample" ], segment_length : float , sampling_rate : int = 4096 , alpha_window : float = 0.1 ):
207+ self ._waveform = FDWaveform .NRSurHyb3dq8_FD .value (target_frequency , segment_length = segment_length , sampling_rate = sampling_rate , alpha_window = alpha_window )
208208
209209 def __call__ (
210210 self , frequency : Float [Array , " n_dim" ], params : dict [str , Float ]
@@ -231,14 +231,14 @@ def __call__(
231231 return output
232232
233233 def __repr__ (self ):
234- return f"JaxNRSurHyb3dq8(segment_length={ self ._waveform .segment_length } , sampling_rate={ self ._waveform .sampling_rate } )"
234+ return f"JaxNRSurHyb3dq8(segment_length={ self ._waveform .surrogate . segment_length } , sampling_rate={ self ._waveform . surrogate .sampling_rate } )"
235235
236236
237237class JaxNRSur7dq4 (Waveform ):
238238 _waveform : FDWaveform .NRSur7dq4_FD .value
239239
240- def __init__ (self , segment_length : float , sampling_rate : int = 4096 , alpha_window : float = 0.1 ):
241- self ._waveform = FDWaveform .NRSur7dq4_FD .value (segment_length = segment_length , sampling_rate = sampling_rate , alpha_window = alpha_window )
240+ def __init__ (self , target_frequency : Float [ Array , " n_sample" ], segment_length : float , sampling_rate : int = 4096 , alpha_window : float = 0.1 ):
241+ self ._waveform = FDWaveform .NRSur7dq4_FD .value (target_frequency , segment_length = segment_length , sampling_rate = sampling_rate , alpha_window = alpha_window )
242242
243243 def __call__ (
244244 self , frequency : Float [Array , " n_dim" ], params : dict [str , Float ]
@@ -270,7 +270,7 @@ def __call__(
270270 return output
271271
272272 def __repr__ (self ):
273- return f"JaxNRSur7dq4(segment_length={ self ._waveform .segment_length } , sampling_rate={ self ._waveform .sampling_rate } )"
273+ return f"JaxNRSur7dq4(segment_length={ self ._waveform .surrogate . segment_length } , sampling_rate={ self ._waveform . surrogate .sampling_rate } )"
274274
275275
276276
0 commit comments