1111logger = logging .getLogger (__name__ )
1212
1313
14- def prepare_extract (ops , U , nC , device = torch .device ('cuda' )):
15- ds = (ops ['xc' ] - ops ['xc' ][:, np .newaxis ])** 2 + (ops ['yc' ] - ops ['yc' ][:, np .newaxis ])** 2
14+ def prepare_extract (xc , yc , U , nC , position_limit , device = torch .device ('cuda' )):
15+ """Identify desired channels based on distances and template norms.
16+
17+ Parameters
18+ ----------
19+ xc : np.ndarray
20+ X-coordinates of contact positions on probe.
21+ yc : np.ndarray
22+ Y-coordinates of contact positions on probe.
23+ U : torch.Tensor
24+ TODO
25+ nC : int
26+ Number of nearest channels to use.
27+ position_limit : float
28+ Max distance (in microns) between channels that are used to estimate
29+ spike positions in `postprocessing.compute_spike_positions`.
30+
31+ Returns
32+ -------
33+ iCC : np.ndarray
34+ For each channel, indices of nC nearest channels.
35+ iCC_mask : np.ndarray
36+ For each channel, a 1 if the channel is within 100um and a 0 otherwise.
37+ Used to control spike position estimate in post-processing.
38+ iU : torch.Tensor
39+ For each template, index of channel with greatest norm.
40+ Ucc : torch.Tensor
41+ For each template, spatial PC features corresponding to iCC.
42+
43+ """
44+ ds = (xc - xc [:, np .newaxis ])** 2 + (yc - yc [:, np .newaxis ])** 2
1645 iCC = np .argsort (ds , 0 )[:nC ]
1746 iCC = torch .from_numpy (iCC ).to (device )
47+ iCC_mask = np .sort (ds , 0 )[:nC ]
48+ iCC_mask = iCC_mask < position_limit ** 2
49+ iCC_mask = torch .from_numpy (iCC_mask ).to (device )
1850 iU = torch .argmax ((U ** 2 ).sum (1 ), - 1 )
1951 Ucc = U [torch .arange (U .shape [0 ]),:,iCC [:,iU ]]
20- return iCC , iU , Ucc
52+
53+ return iCC , iCC_mask , iU , Ucc
54+
2155
2256def extract (ops , bfile , U , device = torch .device ('cuda' ), progress_bar = None ):
2357 nC = ops ['settings' ]['nearest_chans' ]
24- iCC , iU , Ucc = prepare_extract (ops , U , nC , device = device )
58+ position_limit = ops ['settings' ]['position_limit' ]
59+ iCC , iCC_mask , iU , Ucc = prepare_extract (
60+ ops ['xc' ], ops ['yc' ], U , nC , position_limit , device = device
61+ )
2562 ops ['iCC' ] = iCC
63+ ops ['iCC_mask' ] = iCC_mask
2664 ops ['iU' ] = iU
2765 nt = ops ['nt' ]
2866
@@ -85,6 +123,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
85123
86124 return st , tF , ops
87125
126+
88127def align_U (U , ops , device = torch .device ('cuda' )):
89128 Uex = torch .einsum ('xyz, zt -> xty' , U .to (device ), ops ['wPCA' ])
90129 X = Uex .reshape (- 1 , ops ['Nchan' ]).T
@@ -108,6 +147,7 @@ def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')):
108147 Wall3 = Wall3 .transpose (1 ,2 ).to (device )
109148 return Wall3
110149
150+
111151def prepare_matching (ops , U ):
112152 nt = ops ['nt' ]
113153 W = ops ['wPCA' ].contiguous ()
@@ -122,6 +162,7 @@ def prepare_matching(ops, U):
122162
123163 return ctc
124164
165+
125166def run_matching (ops , X , U , ctc , device = torch .device ('cuda' )):
126167 Th = ops ['Th_learned' ]
127168 nt = ops ['nt' ]
0 commit comments