1717logger = logging .getLogger (__name__ )
1818
1919
20- def neigh_mat (Xd , nskip = 20 , n_neigh = 10 , max_sub = None ):
20+ def neigh_mat (Xd , nskip = 1 , n_neigh = 10 , max_sub = 25000 ):
2121 # Xd is spikes by PCA features in a local neighborhood
2222 # finding n_neigh neighbors of each spike to a subset of every nskip spike
2323
2424 # n_samples is the number of spikes, dim is number of features
2525 n_samples , dim = Xd .shape
2626
27- # subsampling the feature matrix
28- if max_sub is not None :
29- # NOTE: Rather than selecting a fixed-size subset, we adjust nskip.
30- # This is much faster than the alternatives we've tried since it's
31- # more-or-less constant speed for arbitrarily large tensors, and it
32- # keeps the logic simple elsewhere in the code.
33- new_nskip = int (np .ceil ((n_samples - 1 )/ (max_sub - 1 )))
34- if new_nskip > nskip : nskip = new_nskip
27+ # Downsample feature matrix by selecting every `nskip`-th spike
3528 Xsub = Xd [::nskip ]
29+ n1 = Xsub .shape [0 ]
30+ # If the downsampled matrix is still larger than max_sub,
31+ # downsample it further by selecting `max_sub` evenly distributed spikes.
32+ if (max_sub is not None ) and (n1 > max_sub ):
33+ n2 = n1 - max_sub
34+ idx , rev_idx = subsample_idx (n1 , n2 )
35+ Xsub = Xsub [idx ]
36+ else :
37+ rev_idx = None
38+
3639 # n_nodes are the # subsampled spikes
3740 n_nodes = Xsub .shape [0 ]
3841
@@ -55,7 +58,10 @@ def neigh_mat(Xd, nskip=20, n_neigh=10, max_sub=None):
5558 (kn .shape [0 ], n_nodes )) # (shape)
5659
5760 # self connections are set to 0
58- M [np .arange (0 ,n_samples ,nskip ), np .arange (n_nodes )] = 0
61+ skip_idx = np .arange (0 , n_samples , nskip )
62+ if rev_idx is not None :
63+ skip_idx = skip_idx [rev_idx ]
64+ M [skip_idx , np .arange (n_nodes )] = 0
5965
6066 return kn , M
6167
@@ -112,7 +118,7 @@ def Mstats(M, device=torch.device('cuda')):
112118 return m , ki , kj
113119
114120
115- def cluster (Xd , iclust = None , kn = None , nskip = 20 , n_neigh = 10 , max_sub = np . inf ,
121+ def cluster (Xd , iclust = None , kn = None , nskip = 1 , n_neigh = 10 , max_sub = 25000 ,
116122 nclust = 200 , seed = 1 , niter = 200 , lam = 0 , device = torch .device ('cuda' ),
117123 verbose = False ):
118124
0 commit comments