1111logger = logging .getLogger (__name__ )
1212
1313
14- def prepare_extract (ops , U , nC , device = torch .device ('cuda' )):
15- ds = (ops ['xc' ] - ops ['xc' ][:, np .newaxis ])** 2 + (ops ['yc' ] - ops ['yc' ][:, np .newaxis ])** 2
14+ def prepare_extract (xc , yc , U , nC , position_limit , device = torch .device ('cuda' )):
15+ """Identify desired channels based on distances and template norms.
16+
17+ Parameters
18+ ----------
19+ xc : np.ndarray
20+ X-coordinates of contact positions on probe.
21+ yc : np.ndarray
22+ Y-coordinates of contact positions on probe.
23+ U : torch.Tensor
24+ TODO
25+ nC : int
26+ Number of nearest channels to use.
27+ position_limit : float
28+ Max distance (in microns) between channels that are used to estimate
29+ spike positions in `postprocessing.compute_spike_positions`.
30+
31+ Returns
32+ -------
33+ iCC : np.ndarray
34+ For each channel, indices of nC nearest channels.
35+ iCC_mask : np.ndarray
36+ For each channel, a 1 if the channel is within 100um and a 0 otherwise.
37+ Used to control spike position estimate in post-processing.
38+ iU : torch.Tensor
39+ For each template, index of channel with greatest norm.
40+ Ucc : torch.Tensor
41+ For each template, spatial PC features corresponding to iCC.
42+
43+ """
44+ ds = (xc - xc [:, np .newaxis ])** 2 + (yc - yc [:, np .newaxis ])** 2
1645 iCC = np .argsort (ds , 0 )[:nC ]
17- iCC = torch .from_numpy (iCC , device = device )
18- iCC_mask = np .sorg (ds , 0 )[:nC ]
19- iCC_mask = iCC_mask < 10000 # 100um squared
20- iCC_mask = torch .from_numpy (iCC_mask , device = device )
46+ iCC = torch .from_numpy (iCC ). to ( device )
47+ iCC_mask = np .sort (ds , 0 )[:nC ]
48+ iCC_mask = iCC_mask < position_limit ** 2
49+ iCC_mask = torch .from_numpy (iCC_mask ). to ( device )
2150 iU = torch .argmax ((U ** 2 ).sum (1 ), - 1 )
2251 Ucc = U [torch .arange (U .shape [0 ]),:,iCC [:,iU ]]
2352
24- # iCC: nC nearest channels to each channel
25- # iCC_mask: 1 if above is within 100um of channel, 0 otherwise
26- # iU: index of max channel for each template
27- # Ucc: spatial PC features corresponding to iCC for each template
28-
2953 return iCC , iCC_mask , iU , Ucc
3054
55+
3156def extract (ops , bfile , U , device = torch .device ('cuda' ), progress_bar = None ):
3257 nC = ops ['settings' ]['nearest_chans' ]
33- iCC , iCC_mask , iU , Ucc = prepare_extract (ops , U , nC , device = device )
58+ position_limit = ops ['settings' ]['position_limit' ]
59+ iCC , iCC_mask , iU , Ucc = prepare_extract (
60+ ops ['xc' ], ops ['yc' ], U , nC , position_limit , device = device
61+ )
3462 ops ['iCC' ] = iCC
3563 ops ['iCC_mask' ] = iCC_mask
3664 ops ['iU' ] = iU
@@ -95,6 +123,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
95123
96124 return st , tF , ops
97125
126+
98127def align_U (U , ops , device = torch .device ('cuda' )):
99128 Uex = torch .einsum ('xyz, zt -> xty' , U .to (device ), ops ['wPCA' ])
100129 X = Uex .reshape (- 1 , ops ['Nchan' ]).T
@@ -118,6 +147,7 @@ def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')):
118147 Wall3 = Wall3 .transpose (1 ,2 ).to (device )
119148 return Wall3
120149
150+
121151def prepare_matching (ops , U ):
122152 nt = ops ['nt' ]
123153 W = ops ['wPCA' ].contiguous ()
@@ -132,6 +162,7 @@ def prepare_matching(ops, U):
132162
133163 return ctc
134164
165+
135166def run_matching (ops , X , U , ctc , device = torch .device ('cuda' )):
136167 Th = ops ['Th_learned' ]
137168 nt = ops ['nt' ]
0 commit comments