Skip to content

Commit 9233270

Browse files
Added fix for spike smearing between shanks
1 parent 5bb4604 commit 9233270

File tree

3 files changed

+64
-17
lines changed

3 files changed

+64
-17
lines changed

kilosort/parameters.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,18 @@
398398
default of 7 bins for a 30kHz sampling rate.
399399
"""
400400
},
401+
402+
'position_limit': {
403+
'gui_name': 'position limit', 'type': float, 'min': 0, 'max': np.inf,
404+
'exclude': [], 'default': 100, 'step': 'postprocessing',
405+
'description':
406+
"""
407+
Maximum distance (in microns) between channels that can be used
408+
to estimate spike positions in `postprocessing.compute_spike_positions`.
409+
This does not affect spike sorting, only how positions are estimated
410+
after sorting is complete.
411+
"""
412+
},
401413
}
402414

403415
# Add default values to descriptions

kilosort/postprocessing.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,22 @@ def remove_duplicates(spike_times, spike_clusters, dt=15):
3232

3333
def compute_spike_positions(st, tF, ops):
3434
'''Get x,y positions of spikes relative to probe.'''
35+
# Determine channel weightings for nearest channels
36+
# based on norm of PC features. Channels that are far away have 0 weight,
37+
# determined by `ops['settings']['position_limit']`.
3538
tmass = (tF**2).sum(-1)
36-
xc = torch.from_numpy(ops['xc']).to(tmass.device)
37-
yc = torch.from_numpy(ops['yc']).to(tmass.device)
38-
# TODO: also store distance to each of these channels, and multiply by
39-
# tmass before summing so that far away channels are ~0
40-
tmask = ops['iCC_mask'][:, ops['iU'][st[:,1]]] # 1 if close enough, 0 if too far away (tbd, maybe 100ish um)
39+
tmask = ops['iCC_mask'][:, ops['iU'][st[:,1]]].T.to(tmass.device)
4140
tmass = tmass * tmask
4241
tmass = tmass / tmass.sum(1, keepdim=True)
42+
43+
# Get x,y coordinates of nearest channels.
44+
xc = torch.from_numpy(ops['xc']).to(tmass.device)
45+
yc = torch.from_numpy(ops['yc']).to(tmass.device)
4346
chs = ops['iCC'][:, ops['iU'][st[:,1]]].cpu()
4447
xc0 = xc[chs.T]
4548
yc0 = yc[chs.T]
4649

50+
# Estimate spike positions as weighted sum of coordinates of nearby channels.
4751
xs = (xc0 * tmass).sum(1).cpu().numpy()
4852
ys = (yc0 * tmass).sum(1).cpu().numpy()
4953

kilosort/template_matching.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,54 @@
1111
logger = logging.getLogger(__name__)
1212

1313

14-
def prepare_extract(ops, U, nC, device=torch.device('cuda')):
15-
ds = (ops['xc'] - ops['xc'][:, np.newaxis])**2 + (ops['yc'] - ops['yc'][:, np.newaxis])**2
14+
def prepare_extract(xc, yc, U, nC, position_limit, device=torch.device('cuda')):
15+
"""Identify desired channels based on distances and template norms.
16+
17+
Parameters
18+
----------
19+
xc : np.ndarray
20+
X-coordinates of contact positions on probe.
21+
yc : np.ndarray
22+
Y-coordinates of contact positions on probe.
23+
U : torch.Tensor
24+
TODO
25+
nC : int
26+
Number of nearest channels to use.
27+
position_limit : float
28+
Max distance (in microns) between channels that are used to estimate
29+
spike positions in `postprocessing.compute_spike_positions`.
30+
31+
Returns
32+
-------
33+
iCC : np.ndarray
34+
For each channel, indices of nC nearest channels.
35+
iCC_mask : np.ndarray
36+
For each channel, a 1 if the channel is within 100um and a 0 otherwise.
37+
Used to control spike position estimate in post-processing.
38+
iU : torch.Tensor
39+
For each template, index of channel with greatest norm.
40+
Ucc : torch.Tensor
41+
For each template, spatial PC features corresponding to iCC.
42+
43+
"""
44+
ds = (xc - xc[:, np.newaxis])**2 + (yc - yc[:, np.newaxis])**2
1645
iCC = np.argsort(ds, 0)[:nC]
17-
iCC = torch.from_numpy(iCC, device=device)
18-
iCC_mask = np.sorg(ds, 0)[:nC]
19-
iCC_mask = iCC_mask < 10000 # 100um squared
20-
iCC_mask = torch.from_numpy(iCC_mask, device=device)
46+
iCC = torch.from_numpy(iCC).to(device)
47+
iCC_mask = np.sort(ds, 0)[:nC]
48+
iCC_mask = iCC_mask < position_limit**2
49+
iCC_mask = torch.from_numpy(iCC_mask).to(device)
2150
iU = torch.argmax((U**2).sum(1), -1)
2251
Ucc = U[torch.arange(U.shape[0]),:,iCC[:,iU]]
2352

24-
# iCC: nC nearest channels to each channel
25-
# iCC_mask: 1 if above is within 100um of channel, 0 otherwise
26-
# iU: index of max channel for each template
27-
# Ucc: spatial PC features corresponding to iCC for each template
28-
2953
return iCC, iCC_mask, iU, Ucc
3054

55+
3156
def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
3257
nC = ops['settings']['nearest_chans']
33-
iCC, iCC_mask, iU, Ucc = prepare_extract(ops, U, nC, device=device)
58+
position_limit = ops['settings']['position_limit']
59+
iCC, iCC_mask, iU, Ucc = prepare_extract(
60+
ops['xc'], ops['yc'], U, nC, position_limit, device=device
61+
)
3462
ops['iCC'] = iCC
3563
ops['iCC_mask'] = iCC_mask
3664
ops['iU'] = iU
@@ -95,6 +123,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
95123

96124
return st, tF, ops
97125

126+
98127
def align_U(U, ops, device=torch.device('cuda')):
99128
Uex = torch.einsum('xyz, zt -> xty', U.to(device), ops['wPCA'])
100129
X = Uex.reshape(-1, ops['Nchan']).T
@@ -118,6 +147,7 @@ def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')):
118147
Wall3 = Wall3.transpose(1,2).to(device)
119148
return Wall3
120149

150+
121151
def prepare_matching(ops, U):
122152
nt = ops['nt']
123153
W = ops['wPCA'].contiguous()
@@ -132,6 +162,7 @@ def prepare_matching(ops, U):
132162

133163
return ctc
134164

165+
135166
def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
136167
Th = ops['Th_learned']
137168
nt = ops['nt']

0 commit comments

Comments
 (0)