@@ -787,3 +787,107 @@ def test_low_pass_filter(alpha):
787787 f"The filtered value at index { i } is not the expected value. "
788788 f"Expected: { expected } , Actual: { filtered_func .source [i ][1 ]} "
789789 )
790+
791+
792+ def test_average_function_ndarray ():
793+
794+ dummy_function = Function (
795+ source = [
796+ [0 , 0 ],
797+ [1 , 1 ],
798+ [2 , 0 ],
799+ [3 , 1 ],
800+ [4 , 0 ],
801+ [5 , 1 ],
802+ [6 , 0 ],
803+ [7 , 1 ],
804+ [8 , 0 ],
805+ [9 , 1 ],
806+ ],
807+ inputs = ["x" ],
808+ outputs = ["y" ],
809+ )
810+ avg_function = dummy_function .average_function ()
811+
812+ assert isinstance (avg_function , Function )
813+ assert np .isclose (avg_function (0 ), 0 )
814+ assert np .isclose (avg_function (9 ), 0.5 )
815+
816+
817+ def test_average_function_callable ():
818+
819+ dummy_function = Function (lambda x : 2 )
820+ avg_function = dummy_function .average_function (lower = 0 )
821+
822+ assert isinstance (avg_function , Function )
823+ assert np .isclose (avg_function (1 ), 2 )
824+ assert np .isclose (avg_function (9 ), 2 )
825+
826+
827+ @pytest .mark .parametrize (
828+ "lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive" ,
829+ [
830+ (0 , 10 , 100 , 1 , 0.5 , True , True ),
831+ (0 , 10 , 100 , 1 , 0.5 , True , False ),
832+ (0 , 10 , 100 , 1 , 0.5 , False , True ),
833+ (0 , 10 , 100 , 1 , 0.5 , False , False ),
834+ (0 , 20 , 200 , 2 , 1 , True , True ),
835+ ],
836+ )
837+ def test_short_time_fft (
838+ lower , upper , sampling_frequency , window_size , step_size , remove_dc , only_positive
839+ ):
840+ """Test the short_time_fft method of the Function class.
841+
842+ Parameters
843+ ----------
844+ lower : float
845+ Lower bound of the time range.
846+ upper : float
847+ Upper bound of the time range.
848+ sampling_frequency : float
849+ Sampling frequency at which to perform the Fourier transform.
850+ window_size : float
851+ Size of the window for the STFT, in seconds.
852+ step_size : float
853+ Step size for the window, in seconds.
854+ remove_dc : bool
855+ If True, the DC component is removed from each window before
856+ computing the Fourier transform.
857+ only_positive: bool
858+ If True, only the positive frequencies are returned.
859+ """
860+ # Generate a test signal
861+ t = np .linspace (lower , upper , int ((upper - lower ) * sampling_frequency ))
862+ signal = np .sin (2 * np .pi * 5 * t ) # 5 Hz sine wave
863+ func = Function (np .column_stack ((t , signal )))
864+
865+ # Perform STFT
866+ stft_results = func .short_time_fft (
867+ lower = lower ,
868+ upper = upper ,
869+ sampling_frequency = sampling_frequency ,
870+ window_size = window_size ,
871+ step_size = step_size ,
872+ remove_dc = remove_dc ,
873+ only_positive = only_positive ,
874+ )
875+
876+ # Check the results
877+ assert isinstance (stft_results , list )
878+ assert all (isinstance (f , Function ) for f in stft_results )
879+
880+ for f in stft_results :
881+ assert f .get_inputs () == ["Frequency (Hz)" ]
882+ assert f .get_outputs () == ["Amplitude" ]
883+ assert f .get_interpolation_method () == "linear"
884+ assert f .get_extrapolation_method () == "zero"
885+
886+ frequencies = f .source [:, 0 ]
887+ # amplitudes = f.source[:, 1]
888+
889+ if only_positive :
890+ assert np .all (frequencies >= 0 )
891+ else :
892+ assert np .all (frequencies >= - sampling_frequency / 2 )
893+ assert np .all (frequencies <= sampling_frequency / 2 )
0 commit comments