@@ -53,30 +53,33 @@ def _pocketfft_oaconvolve_block(Q, T, conv_block_size):
5353 return c2r (False , np .multiply (fft_2d [:- 1 ], fft_2d [[- 1 ]]), n = conv_block_size )
5454
5555
56- def _pocketfft_oaconvolve (Q , T , conv_block_size ):
56+ def _pocketfft_valid_oaconvolve (Q , T , conv_block_size ):
5757 QT_conv_blocks = _pocketfft_oaconvolve_block (Q , T , conv_block_size )
5858 overlap = len (Q ) - 1
5959 out = QT_conv_blocks [:, :- overlap ]
6060 out [1 :, :overlap ] += QT_conv_blocks [:- 1 , - overlap :]
61- return np .reshape (out , (- 1 ,))
6261
62+ return np .reshape (out , (- 1 ,))[len (Q ) - 1 : len (T )]
6363
64- def _sliding_dot_product (Q , T , conv_block_size ):
65- return _pocketfft_oaconvolve (Q [::- 1 ], T , conv_block_size )[len (Q ) - 1 : len (T )]
64+
65+ def _valid_convolve (Q , T , conv_block_size = None ):
66+ m = len (Q )
67+ n = len (T )
68+ conv_block_size = _compute_block_size (m , n , conv_block_size = conv_block_size )
69+ if conv_block_size >= n :
70+ out = pocketfft_r2c_c2r_sdp ._pocketfft_valid_convolve (Q , T )
71+ else :
72+ out = _pocketfft_valid_oaconvolve (Q , T , conv_block_size )
73+
74+ return out
6675
6776
6877def setup (Q , T ):
6978 return
7079
7180
7281def sliding_dot_product (Q , T , conv_block_size = None ):
73- m = Q .shape [0 ]
74- n = T .shape [0 ]
75- if m == n :
82+ if len (Q ) == len (T ):
7683 return np .dot (Q , T )
77-
78- conv_block_size = _compute_block_size (m , n , conv_block_size = conv_block_size )
79- if conv_block_size >= n :
80- return pocketfft_r2c_c2r_sdp .sliding_dot_product (Q , T )
8184 else :
82- return _sliding_dot_product ( Q , T , conv_block_size )
85+ return _valid_convolve ( Q [:: - 1 ] , T , conv_block_size = conv_block_size )
0 commit comments