Skip to content

Commit aa4c879

Browse files
Merge pull request #962 from MouseLand/jacob/max_sub_update
Jacob/max sub update
2 parents 8d2b46d + 7af6f3e commit aa4c879

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

kilosort/clustering_qr.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,25 @@
1717
logger = logging.getLogger(__name__)
1818

1919

20-
def neigh_mat(Xd, nskip=20, n_neigh=10, max_sub=None):
20+
def neigh_mat(Xd, nskip=1, n_neigh=10, max_sub=25000):
2121
# Xd is spikes by PCA features in a local neighborhood
2222
# finding n_neigh neighbors of each spike to a subset of every nskip spike
2323

2424
# n_samples is the number of spikes, dim is number of features
2525
n_samples, dim = Xd.shape
2626

27-
# subsampling the feature matrix
28-
if max_sub is not None:
29-
# NOTE: Rather than selecting a fixed-size subset, we adjust nskip.
30-
# This is much faster than the alternatives we've tried since it's
31-
# more-or-less constant speed for arbitrarily large tensors, and it
32-
# keeps the logic simple elsewhere in the code.
33-
new_nskip = int(np.ceil((n_samples-1)/(max_sub-1)))
34-
if new_nskip > nskip: nskip = new_nskip
27+
# Downsample feature matrix by selecting every `nskip`-th spike
3528
Xsub = Xd[::nskip]
29+
n1 = Xsub.shape[0]
30+
# If the downsampled matrix is still larger than max_sub,
31+
# downsample it further by selecting `max_sub` evenly distributed spikes.
32+
if (max_sub is not None) and (n1 > max_sub):
33+
n2 = n1 - max_sub
34+
idx, rev_idx = subsample_idx(n1, n2)
35+
Xsub = Xsub[idx]
36+
else:
37+
rev_idx = None
38+
3639
# n_nodes are the # subsampled spikes
3740
n_nodes = Xsub.shape[0]
3841

@@ -55,7 +58,10 @@ def neigh_mat(Xd, nskip=20, n_neigh=10, max_sub=None):
5558
(kn.shape[0], n_nodes)) # (shape)
5659

5760
# self connections are set to 0
58-
M[np.arange(0,n_samples,nskip), np.arange(n_nodes)] = 0
61+
skip_idx = np.arange(0, n_samples, nskip)
62+
if rev_idx is not None:
63+
skip_idx = skip_idx[rev_idx]
64+
M[skip_idx, np.arange(n_nodes)] = 0
5965

6066
return kn, M
6167

@@ -112,7 +118,7 @@ def Mstats(M, device=torch.device('cuda')):
112118
return m, ki, kj
113119

114120

115-
def cluster(Xd, iclust=None, kn=None, nskip=20, n_neigh=10, max_sub=np.inf,
121+
def cluster(Xd, iclust=None, kn=None, nskip=1, n_neigh=10, max_sub=25000,
116122
nclust=200, seed=1, niter=200, lam=0, device=torch.device('cuda'),
117123
verbose=False):
118124

kilosort/parameters.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,17 +384,21 @@
384384

385385
'cluster_downsampling': {
386386
'gui_name': 'cluster downsampling', 'type': int, 'min': 1, 'max': np.inf,
387-
'exclude': [], 'default': 20, 'step': 'clustering',
387+
'exclude': [], 'default': 1, 'step': 'clustering',
388388
'description':
389389
"""
390-
Inverse fraction of nodes used as landmarks during clustering
391-
(can be 1, but that slows down the optimization).
390+
Inverse fraction of spikes used as landmarks during clustering. By
391+
default, all spikes are used up to a maximum of
392+
`max_cluster_subset=25000`.
393+
394+
The old default behavior (version < 4.1.0) is
395+
equivalent to `max_cluster_subset=None, cluster_downsampling=20`.
392396
"""
393397
},
394398

395399
'max_cluster_subset': {
396400
'gui_name': 'max cluster subset', 'type': int, 'min': 1, 'max': np.inf,
397-
'exclude': [], 'default': None, 'step': 'clustering',
401+
'exclude': [np.inf], 'default': 25000, 'step': 'clustering',
398402
'description':
399403
"""
400404
Maximum number of spikes to use when searching for nearest neighbors
@@ -405,13 +409,16 @@
405409
bound for very long recordings. Using a very large number of spikes
406410
is not necessary and causes performance bottlenecks.
407411
412+
Use `max_cluster_subset = None` if you do not want a limit on
413+
the subset size. The old default behavior (version < 4.1.0) is
414+
equivalent to `max_cluster_subset=None, cluster_downsampling=20`.
415+
408416
Note: In practice, the actual number of spikes used may increase or
409417
decrease slightly while staying under the maximum. This happens
410418
because the maximum is set by adjusting `cluster_downsampling` on the
411419
fly so that it results in a set no larger than the given size.
412420
"""
413421
},
414-
# TODO: Add suggested values after more testing on different datasets.
415422

416423
'x_centers': {
417424
'gui_name': 'x centers', 'type': int, 'min': 1,

tests/test_full_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ def test_pipeline(data_directory, results_directory, saved_ops, torch_device, ca
1717
)
1818

1919
with capture_mgr.global_and_fixture_disabled():
20+
# NOTE: 'cluster_downsampling' and 'max_cluster_subset' are set to be
21+
# equivalent to their default behavior prior to version 4.1.0,
22+
# since that was how the test results were generated.
2023
print('\nStarting run_kilosort test...')
2124
ops, st, clu, _, _, _, _, _, kept_spikes = run_kilosort(
2225
filename=bin_file, device=torch_device,
23-
settings={'n_chan_bin': 385},
26+
settings={'n_chan_bin': 385, 'cluster_downsampling': 20,
27+
'max_cluster_subset': None},
2428
probe_name='NeuroPix1_default.mat',
2529
verbose_console=True
2630
)

0 commit comments

Comments
 (0)