Skip to content

Commit bccf87f

Browse files
Fixed several memory issues in clustering_qr
1 parent eac3c3a commit bccf87f

File tree

2 files changed

+14
-24
lines changed

2 files changed

+14
-24
lines changed

kilosort/clustering_qr.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def cluster(Xd, iclust=None, kn=None, nskip=20, n_neigh=10, nclust=200, seed=1,
183183
def kmeans_plusplus(Xg, niter=200, seed=1, device=torch.device('cuda'), verbose=False):
184184
# Xg is number of spikes by number of features.
185185
# We are finding cluster centroids and assigning each spike to a centroid.
186-
vtot = (Xg**2).sum(1)
186+
vtot = torch.norm(Xg, 2, dim=1)**2
187187

188188
n1 = vtot.shape[0]
189189
if n1 > 2**24:
@@ -221,10 +221,6 @@ def kmeans_plusplus(Xg, niter=200, seed=1, device=torch.device('cuda'), verbose=
221221
# v2 is the un-explained variance so far for each spike
222222
v2 = torch.relu(vtot - vexp0)
223223

224-
# TODO: Where to actually apply subsampling of Xg?
225-
# Okay to draw isamp from the subsample? If not, how to reconcile
226-
# vtot size (all of Xg) vs vexp0 size (subsample size)?
227-
228224
# We sample ntry new candidate centroids based on how much un-explained variance they have
229225
# more unexplained variance makes it more likely to be selected
230226
# Only one of these candidates will be added this iteration.
@@ -271,7 +267,9 @@ def kmeans_plusplus(Xg, niter=200, seed=1, device=torch.device('cuda'), verbose=
271267
if verbose:
272268
log_performance(logger, header='clustering_qr.kpp, after loop')
273269

274-
# if the clustering above is done on a subset of Xg, then we need to assign all Xgs here to get an iclust
270+
# NOTE: For very large datasets, we may end up needing to subsample Xg.
271+
# If the clustering above is done on a subset of Xg,
272+
# then we need to assign all Xgs here to get an iclust
275273
# for ii in range((len(Xg)-1)//nblock +1):
276274
# vexp = 2 * Xg[ii*nblock:(ii+1)*nblock] @ mu.T - (mu**2).sum(1)
277275
# iclust[ii*nblock:(ii+1)*nblock] = torch.argmax(vexp, dim=-1)
@@ -349,9 +347,6 @@ def xy_templates(ops):
349347
xy = np.vstack((xcup, ycup))
350348
xy = torch.from_numpy(xy)
351349

352-
iU = ops['iU'].cpu().numpy()
353-
iC = ops['iCC'][:, ops['iU']]
354-
355350
return xy, iC
356351

357352

@@ -495,7 +490,7 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
495490

496491
ix = (nearest_center == ii)
497492
ntemp = ix.sum()
498-
Xd, ch_min, ch_max, igood = get_data_cpu(
493+
Xd, igood, ichan = get_data_cpu(
499494
ops, xy, iC, iclust_template, tF, ycent[kk], xcent[jj],
500495
dmin=dmin, dminx=dminx, ix=ix,
501496
)
@@ -551,7 +546,7 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
551546
W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs']))
552547
for j in range(Nfilt):
553548
w = Xd[iclust==j].mean(0)
554-
W[j, ch_min:ch_max, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu()
549+
W[j, ichan, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu()
555550

556551
Wall = torch.cat((Wall, W), 0)
557552

@@ -597,13 +592,11 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
597592
y0 = ycenter # xy[1].mean() - ycenter
598593
x0 = xcenter #xy[0].mean() - xcenter
599594

600-
#print(dmin, dminx)
601595
if ix is None:
602596
ix = torch.logical_and(
603597
torch.abs(xy[1] - y0) < dmin,
604598
torch.abs(xy[0] - x0) < dminx
605599
)
606-
#print(ix.nonzero()[:,0])
607600
igood = ix[PID].nonzero()[:,0]
608601

609602
if len(igood)==0:
@@ -612,25 +605,22 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
612605
pid = PID[igood]
613606
data = tF[igood]
614607
nspikes, nchanraw, nfeatures = data.shape
615-
ichan = torch.unique(iC[:, ix])
616-
ch_min = torch.min(ichan)
617-
ch_max = torch.max(ichan)+1
618-
nchan = ch_max - ch_min
608+
ichan, imap = torch.unique(iC[:, ix], return_inverse=True)
609+
print(ichan)
610+
nchan = ichan.nelement()
619611

620612
dd = torch.zeros((nspikes, nchan, nfeatures))
621-
for j in ix.nonzero()[:,0]:
613+
for k,j in enumerate(ix.nonzero()[:,0]):
622614
ij = torch.nonzero(pid==j)[:, 0]
623-
#print(ij.sum())
624-
dd[ij.unsqueeze(-1), iC[:,j]-ch_min] = data[ij]
615+
dd[ij.unsqueeze(-1), imap[:,k]] = data[ij]
625616

626617
if merge_dim:
627618
Xd = torch.reshape(dd, (nspikes, -1))
628619
else:
629620
# Keep channels and features separate
630621
Xd = dd
631622

632-
return Xd, ch_min, ch_max, igood
633-
623+
return Xd, igood, ichan
634624

635625

636626
def assign_clust(rows_neigh, iclust, kn, tones2, nclust):

kilosort/postprocessing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def make_pc_features(ops, spike_templates, spike_clusters, tF):
102102
ix[iunq] = True
103103
# Get PC features for all spikes detected with those templates (Xd),
104104
# and the indices in tF where those spikes occur (igood).
105-
Xd, ch_min, ch_max, igood = get_data_cpu(
105+
Xd, igood, ichan = get_data_cpu(
106106
ops, xy, iC, spike_templates, tF, None, None,
107107
dmin=ops['dmin'], dminx=ops['dminx'], ix=ix, merge_dim=False
108108
)
@@ -114,7 +114,7 @@ def make_pc_features(ops, spike_templates, spike_clusters, tF):
114114
# Assign features to overwrite tF in-place
115115
tF[igood,:] = Xd[:, ind[:n_chans], :]
116116
# Save channel inds for phy
117-
feature_ind[i,:] = ind[:n_chans].numpy() + ch_min.cpu().numpy()
117+
feature_ind[i,:] = ichan[ind[:n_chans]].cpu().numpy()
118118

119119
# Swap last 2 dimensions to get ordering Phy expects
120120
tF = torch.permute(tF, (0, 2, 1))

0 commit comments

Comments (0)