Skip to content

Commit f6fed15

Browse files
committed
renaming private functions to reflect valid convolution
1 parent 39e936c commit f6fed15

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

sdp/challenger_sdp.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,30 +53,33 @@ def _pocketfft_oaconvolve_block(Q, T, conv_block_size):
5353
return c2r(False, np.multiply(fft_2d[:-1], fft_2d[[-1]]), n=conv_block_size)
5454

5555

56-
def _pocketfft_oaconvolve(Q, T, conv_block_size):
56+
def _pocketfft_valid_oaconvolve(Q, T, conv_block_size):
5757
QT_conv_blocks = _pocketfft_oaconvolve_block(Q, T, conv_block_size)
5858
overlap = len(Q) - 1
5959
out = QT_conv_blocks[:, :-overlap]
6060
out[1:, :overlap] += QT_conv_blocks[:-1, -overlap:]
61-
return np.reshape(out, (-1,))
6261

62+
return np.reshape(out, (-1,))[len(Q) - 1 : len(T)]
6363

64-
def _sliding_dot_product(Q, T, conv_block_size):
65-
return _pocketfft_oaconvolve(Q[::-1], T, conv_block_size)[len(Q) - 1 : len(T)]
64+
65+
def _valid_convolve(Q, T, conv_block_size=None):
66+
m = len(Q)
67+
n = len(T)
68+
conv_block_size = _compute_block_size(m, n, conv_block_size=conv_block_size)
69+
if conv_block_size >= n:
70+
out = pocketfft_r2c_c2r_sdp._pocketfft_valid_convolve(Q, T)
71+
else:
72+
out = _pocketfft_valid_oaconvolve(Q, T, conv_block_size)
73+
74+
return out
6675

6776

6877
def setup(Q, T):
6978
return
7079

7180

7281
def sliding_dot_product(Q, T, conv_block_size=None):
73-
m = Q.shape[0]
74-
n = T.shape[0]
75-
if m == n:
82+
if len(Q) == len(T):
7683
return np.dot(Q, T)
77-
78-
conv_block_size = _compute_block_size(m, n, conv_block_size=conv_block_size)
79-
if conv_block_size >= n:
80-
return pocketfft_r2c_c2r_sdp.sliding_dot_product(Q, T)
8184
else:
82-
return _sliding_dot_product(Q, T, conv_block_size)
85+
return _valid_convolve(Q[::-1], T, conv_block_size=conv_block_size)

sdp/pocketfft_r2c_c2r_sdp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scipy.fft._pocketfft.basic import r2c, c2r
44

55

6-
def _pocketfft_convolve(Q, T):
6+
def _pocketfft_valid_convolve(Q, T):
77
n = len(T)
88
m = len(Q)
99
next_fast_n = next_fast_len(n, real=True)
@@ -15,12 +15,14 @@ def _pocketfft_convolve(Q, T):
1515
tmp[1, n:] = 0.0
1616
fft_2d = r2c(True, tmp, axis=-1)
1717

18-
return c2r(False, np.multiply(fft_2d[0], fft_2d[1]), n=next_fast_n)
18+
return c2r(False, np.multiply(fft_2d[0], fft_2d[1]), n=next_fast_n)[
19+
len(Q) - 1 : len(T)
20+
]
1921

2022

2123
def setup(Q, T):
2224
return
2325

2426

2527
def sliding_dot_product(Q, T):
26-
return _pocketfft_convolve(Q[::-1], T)[len(Q) - 1 : len(T)]
28+
return _pocketfft_valid_convolve(Q[::-1], T)

0 commit comments

Comments
 (0)