Skip to content

Commit 72813b6

Browse files
committed
More speed
1 parent 86b993a commit 72813b6

1 file changed

Lines changed: 45 additions & 4 deletions

File tree

mne_rsa/rdm.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def compute_rdm(data, metric="correlation", **kwargs):
4141
4242
Notes
4343
-----
44-
The distance metrics "euclidean" and "sqeuclidean" will use a custom implementation
45-
instead of :func:`scipy.spatial.distance.pdist`.
44+
The distance metrics "euclidean", "sqeuclidean", "cosine", and "correlation" will
45+
use a custom implementation instead of :func:`scipy.spatial.distance.pdist`.
4646
4747
See Also
4848
--------
@@ -66,6 +66,10 @@ def compute_rdm(data, metric="correlation", **kwargs):
6666
rdm = _sqeuclidean(X)
6767
elif metric == "euclidean":
6868
rdm = np.sqrt(_sqeuclidean(X))
69+
elif metric == "cosine":
70+
rdm = _cosine(X)
71+
elif metric == "correlation":
72+
rdm = _correlation(X)
6973
else:
7074
# Use scikit learn's distance computation.
7175
rdm = distance.pdist(X, metric=metric, **kwargs)
@@ -113,8 +117,9 @@ def compute_rdm_cv(folds, metric="correlation", **kwargs):
113117
114118
Notes
115119
-----
116-
The distance metrics "euclidean", "sqeuclidean" and "crossnobis" will use a custom
117-
implementation instead of :func:`scipy.spatial.distance.pdist`.
120+
The distance metrics "euclidean", "sqeuclidean", "crossnobis", "cosine", and
121+
"correlation" will use a custom implementation instead of
122+
:func:`scipy.spatial.distance.pdist`.
118123
119124
"""
120125
X = np.reshape(folds, (folds.shape[0], folds.shape[1], -1))
@@ -179,6 +184,42 @@ def _sqeuclidean(X):
179184
return D
180185

181186

187+
def _cosine(X):
188+
"""Fast item-to-item cosine distance.
189+
190+
Parameters
191+
----------
192+
X : ndarray, shape (n_items, n_features)
193+
For each item, all the features.
194+
195+
Returns
196+
-------
197+
D : ndarray, shape (n_items * n_items)
198+
The item-to-item distance matrix
199+
200+
"""
201+
return _sqeuclidean(X / np.linalg.norm(X, axis=1, keepdims=True))
202+
203+
204+
def _correlation(X):
205+
"""Fast item-to-item Pearson correlation distance.
206+
207+
Parameters
208+
----------
209+
X : ndarray, shape (n_items, n_features)
210+
For each item, all the features.
211+
212+
Returns
213+
-------
214+
D : ndarray, shape (n_items * n_items)
215+
The item-to-item distance matrix
216+
217+
"""
218+
X = X - X.mean(axis=1, keepdims=True)
219+
X /= np.linalg.norm(X, axis=1, keepdims=True)
220+
return _sqeuclidean(X)
221+
222+
182223
def _crossnobis(X):
183224
"""Fast cross-validated item-to-item squared Euclidean distance (crossnobis).
184225

0 commit comments

Comments
 (0)