@@ -183,7 +183,7 @@ def cluster(Xd, iclust=None, kn=None, nskip=20, n_neigh=10, nclust=200, seed=1,
183183def kmeans_plusplus (Xg , niter = 200 , seed = 1 , device = torch .device ('cuda' ), verbose = False ):
184184 # Xg is number of spikes by number of features.
185185 # We are finding cluster centroids and assigning each spike to a centroid.
186- vtot = (Xg ** 2 ). sum ( 1 )
186+ vtot = torch . norm (Xg , 2 , dim = 1 ) ** 2
187187
188188 n1 = vtot .shape [0 ]
189189 if n1 > 2 ** 24 :
@@ -221,10 +221,6 @@ def kmeans_plusplus(Xg, niter=200, seed=1, device=torch.device('cuda'), verbose=
221221 # v2 is the un-explained variance so far for each spike
222222 v2 = torch .relu (vtot - vexp0 )
223223
224- # TODO: Where to actually apply subsampling of Xg?
225- # Okay to draw isamp from the subsample? If not, how to reconcile
226- # vtot size (all of Xg) vs vexp0 size (subsample size)?
227-
228224 # We sample ntry new candidate centroids based on how much un-explained variance they have
229225 # (spikes with more un-explained variance are more likely to be selected).
230226 # Only one of these candidates will be added this iteration.
@@ -271,7 +267,9 @@ def kmeans_plusplus(Xg, niter=200, seed=1, device=torch.device('cuda'), verbose=
271267 if verbose :
272268 log_performance (logger , header = 'clustering_qr.kpp, after loop' )
273269
274- # if the clustering above is done on a subset of Xg, then we need to assign all Xgs here to get an iclust
270+ # NOTE: For very large datasets, we may end up needing to subsample Xg.
271+ # If the clustering above is done on a subset of Xg, then every row of Xg
272+ # must be assigned to a centroid here (block by block) to get an iclust
275273 # for ii in range((len(Xg)-1)//nblock +1):
276274 # vexp = 2 * Xg[ii*nblock:(ii+1)*nblock] @ mu.T - (mu**2).sum(1)
277275 # iclust[ii*nblock:(ii+1)*nblock] = torch.argmax(vexp, dim=-1)
@@ -349,9 +347,6 @@ def xy_templates(ops):
349347 xy = np .vstack ((xcup , ycup ))
350348 xy = torch .from_numpy (xy )
351349
352- iU = ops ['iU' ].cpu ().numpy ()
353- iC = ops ['iCC' ][:, ops ['iU' ]]
354-
355350 return xy , iC
356351
357352
@@ -495,7 +490,7 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
495490
496491 ix = (nearest_center == ii )
497492 ntemp = ix .sum ()
498- Xd , ch_min , ch_max , igood = get_data_cpu (
493+ Xd , igood , ichan = get_data_cpu (
499494 ops , xy , iC , iclust_template , tF , ycent [kk ], xcent [jj ],
500495 dmin = dmin , dminx = dminx , ix = ix ,
501496 )
@@ -551,7 +546,7 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
551546 W = torch .zeros ((Nfilt , ops ['Nchan' ], ops ['settings' ]['n_pcs' ]))
552547 for j in range (Nfilt ):
553548 w = Xd [iclust == j ].mean (0 )
554- W [j , ch_min : ch_max , :] = torch .reshape (w , (- 1 , ops ['settings' ]['n_pcs' ])).cpu ()
549+ W [j , ichan , :] = torch .reshape (w , (- 1 , ops ['settings' ]['n_pcs' ])).cpu ()
555550
556551 Wall = torch .cat ((Wall , W ), 0 )
557552
@@ -597,13 +592,11 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
597592 y0 = ycenter # xy[1].mean() - ycenter
598593 x0 = xcenter #xy[0].mean() - xcenter
599594
600- #print(dmin, dminx)
601595 if ix is None :
602596 ix = torch .logical_and (
603597 torch .abs (xy [1 ] - y0 ) < dmin ,
604598 torch .abs (xy [0 ] - x0 ) < dminx
605599 )
606- #print(ix.nonzero()[:,0])
607600 igood = ix [PID ].nonzero ()[:,0 ]
608601
609602 if len (igood )== 0 :
@@ -612,25 +605,22 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
612605 pid = PID [igood ]
613606 data = tF [igood ]
614607 nspikes , nchanraw , nfeatures = data .shape
615- ichan = torch .unique (iC [:, ix ])
616- ch_min = torch .min (ichan )
617- ch_max = torch .max (ichan )+ 1
618- nchan = ch_max - ch_min
608+ ichan , imap = torch .unique (iC [:, ix ], return_inverse = True )
609+ print (ichan )
610+ nchan = ichan .nelement ()
619611
620612 dd = torch .zeros ((nspikes , nchan , nfeatures ))
621- for j in ix .nonzero ()[:,0 ]:
613+ for k , j in enumerate ( ix .nonzero ()[:,0 ]) :
622614 ij = torch .nonzero (pid == j )[:, 0 ]
623- #print(ij.sum())
624- dd [ij .unsqueeze (- 1 ), iC [:,j ]- ch_min ] = data [ij ]
615+ dd [ij .unsqueeze (- 1 ), imap [:,k ]] = data [ij ]
625616
626617 if merge_dim :
627618 Xd = torch .reshape (dd , (nspikes , - 1 ))
628619 else :
629620 # Keep channels and features separate
630621 Xd = dd
631622
632- return Xd , ch_min , ch_max , igood
633-
623+ return Xd , igood , ichan
634624
635625
636626def assign_clust (rows_neigh , iclust , kn , tones2 , nclust ):
0 commit comments