Skip to content

Commit efe699a

Browse files
committed
two array combinations fix
1 parent 598d885 commit efe699a

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

neurodsp/rhythm/phase.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ def pairwise_phase_consistency(pha0, pha1=None, return_pairs=True, progress=None
6363

6464
else:
6565

66-
n_combs = int((len(pha0) * (len(pha0) + 1)) / 2)
66+
# Include all combinations
67+
n_combs = int(len(pha0) ** 2)
6768

68-
# Include self-combinations
69-
iterable = enumerate(combinations_with_replacement(np.arange(len(pha0)), 2))
69+
iterable = enumerate((row, col) for row in range(len(pha0)) for col in range(len(pha1)))
7070

7171
# Initialize variables
7272
if return_pairs:
7373
cumulative = None
74-
distances = np.zeros(n_combs)
74+
distances = np.ones((len(pha0), len(pha0)))
7575
else:
7676
cumulative = 0
7777
distances = None
@@ -88,7 +88,7 @@ def pairwise_phase_consistency(pha0, pha1=None, return_pairs=True, progress=None
8888
# Compute distance indices
8989
for idx, pair in iterable:
9090

91-
phi0= pha0[pair[0]]
91+
phi0 = pha0[pair[0]]
9292

9393
if pha1 is None:
9494
phi1 = pha0[pair[1]]
@@ -108,8 +108,11 @@ def pairwise_phase_consistency(pha0, pha1=None, return_pairs=True, progress=None
108108
# Pairwise circular distance index (PCDI)
109109
distance = (np.pi - 2 * abs_dist) / np.pi
110110

111-
if isinstance(distances, np.ndarray):
112-
distances[idx] = distance
111+
if isinstance(distances, np.ndarray) and pha1 is None:
112+
distances[pair[0], pair[1]] = distance
113+
distances[pair[1], pair[0]] = distance
114+
elif isinstance(distances, np.ndarray) and pha1 is not None:
115+
distances[pair[0], pair[1]] = distance
113116
else:
114117
cumulative += distance
115118

neurodsp/tests/rhythm/test_phase.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_pairwise_phase_consistency(tsig_sine, return_pairs, phase_shift):
3232
dist_avg, dists = dist_avg[0], dist_avg[1]
3333

3434
assert isinstance(dists, np.ndarray)
35-
assert len(dists) == (len(peaks) * (len(peaks) + 1)) / 2
35+
assert len(dists[0]) * len(dists[1]) == len(peaks) ** 2
3636
assert np.mean(dists) == dist_avg
3737

3838
# Expected consistency
@@ -50,7 +50,7 @@ def test_pairwise_phase_consistency(tsig_sine, return_pairs, phase_shift):
5050
dist_avg, dists = pairwise_phase_consistency(pha0[peaks], return_pairs=True)
5151

5252
assert dist_avg == 1
53-
assert len(dists) == (len(peaks) * (len(peaks) - 1)) / 2
53+
assert len(dists[0]) == len(dists[1]) == len(peaks)
5454

5555
# Cases where arrays are invalid sizes
5656
try:

0 commit comments

Comments
 (0)