Skip to content

Commit aaf4f4e

Browse files
committed
Datashift using rigid registration (old version - see next commit)
1 parent b4a4949 commit aaf4f4e

File tree

5 files changed

+941
-0
lines changed

5 files changed

+941
-0
lines changed

pykilosort/cluster.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,34 @@
1212
logger = logging.getLogger(__name__)
1313

1414

15+
def getClosestChannels2(ycup, xcup, yc, xc, NchanClosest):
16+
# this function outputs the closest channels to each channel,
17+
# as well as a Gaussian-decaying mask as a function of pairwise distances
18+
# sigma is the standard deviation of this Gaussian-mask
19+
20+
# compute distances between all pairs of channels
21+
xc = cp.asarray(probe.xc, dtype=np.float32, order='F')
22+
yc = cp.asarray(probe.yc, dtype=np.float32, order='F')
23+
xcup = cp.asarray(xcup, dtype=np.float32, order='F')
24+
ycup = cp.asarray(ycup, dtype=np.float32, order='F')
25+
C2C = (xc[:] - xcup[:].T)^2 + (yc[:] - ycup[:].T).^2
26+
C2C = cp.sqrt(C2C)
27+
Nchan, NchanUp C2C.shape
28+
29+
# sort distances
30+
isort = cp.argsort(C2C, axis=0)
31+
32+
# take NchanCLosest neighbors for each primary channel
33+
iC = isort[:NchanClosest, :]
34+
35+
# in some cases we want a mask that decays as a function of distance between pairs of channels
36+
# this is an awkward indexing to get the corresponding distances
37+
ix = iC + cp.arange(0, Nchan * NchanUp, Nchan)
38+
dist = C2C[ix]
39+
40+
return iC, dist
41+
42+
1543
def getClosestChannels(probe, sigma, NchanClosest):
1644
# this function outputs the closest channels to each channel,
1745
# as well as a Gaussian-decaying mask as a function of pairwise distances

pykilosort/cuda/spikedetector3.cu

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
const int Nthreads = 1024, NrankMax = 6, maxFR = 10000, nt0max=81, NchanMax = 17, nsizes = 5;
2+
3+
4+
//////////////////////////////////////////////////////////////////////////////////////////
5+
__global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){
6+
volatile __shared__ float sW[81*NrankMax], sdata[(Nthreads+81)];
7+
float y;
8+
int tid, tid0, bid, i, nid, Nrank, NT, nt0, Nchan;
9+
10+
tid = threadIdx.x;
11+
bid = blockIdx.x;
12+
13+
NT = (int) Params[0];
14+
Nchan = (int) Params[1];
15+
nt0 = (int) Params[2];
16+
Nrank = (int) Params[4];
17+
18+
if(tid<nt0*Nrank)
19+
sW[tid]= W[tid];
20+
__syncthreads();
21+
22+
tid0 = 0;
23+
while (tid0<NT-Nthreads-nt0+1){
24+
if (tid<nt0)
25+
sdata[tid] = data[tid0 + tid + NT*bid];
26+
sdata[tid + nt0] = data[nt0+tid0 + tid+ NT*bid];
27+
__syncthreads();
28+
29+
for(nid=0;nid<Nrank;nid++){
30+
y = 0.0f;
31+
#pragma unroll 4
32+
for(i=0;i<nt0;i++)
33+
y += sW[i + nid*nt0] * sdata[i+tid];
34+
conv_sig[tid0 + tid + NT*bid + nid * NT * Nchan] = y;
35+
}
36+
tid0+=Nthreads;
37+
__syncthreads();
38+
}
39+
}
40+
41+
//////////////////////////////////////////////////////////////////////////////////////////
42+
__global__ void sumChannels(const double *Params, const float *data,
43+
float *datasum, int *kkmax, const int *iC2, const float *dist, const float *v2){
44+
45+
int tid, tid0,t,k, kmax, bidx, bidy, NT, Nchan, NchanNear,j,iChan, Nsum, Nrank;
46+
float Cmax, C0;
47+
float a[nsizes], d2;
48+
float sigma;
49+
volatile __shared__ float sA[nsizes * 20];
50+
51+
52+
tid = threadIdx.x;
53+
bidx = blockIdx.x;
54+
bidy = blockIdx.y;
55+
NT = (int) Params[0];
56+
Nchan = (int) Params[1];
57+
NchanNear = (int) Params[3];
58+
Nrank = (int) Params[4];
59+
Nsum = (int) Params[3];
60+
sigma = (float) Params[9];
61+
62+
if (tid<nsizes*NchanNear){
63+
d2 = dist[tid/nsizes + NchanNear * bidy];
64+
k = tid%nsizes;
65+
sA[tid] = expf( - (d2 * d2)/((1+k)*(1+k)*sigma*sigma));
66+
}
67+
__syncthreads();
68+
69+
tid0 = tid + bidx * blockDim.x;
70+
while (tid0<NT){
71+
Cmax = 0.0f;
72+
kmax = 0;
73+
74+
for (t=0;t<Nrank;t++){
75+
for(k=0; k<nsizes; k++)
76+
a[k] = 0.;
77+
78+
for(j=0; j<Nsum; j++){
79+
iChan = iC2[j + NchanNear * bidy];
80+
for(k=0; k<nsizes; k++)
81+
a[k] += sA[k + nsizes * j] *
82+
data[tid0 + NT * iChan + t * NT * Nchan];
83+
}
84+
for(k=0; k<nsizes; k++){
85+
a[k] = max(a[k], 0.);
86+
if (a[k]*a[k] / v2[k + nsizes*bidy] > Cmax){
87+
Cmax = a[k]*a[k]/v2[k + nsizes*bidy];
88+
kmax = t + k*Nrank;
89+
}
90+
}
91+
}
92+
datasum[tid0 + NT * bidy] = Cmax;
93+
kkmax[tid0 + NT * bidy] = kmax;
94+
95+
tid0 += blockDim.x * gridDim.x;
96+
}
97+
}
98+
99+
//////////////////////////////////////////////////////////////////////////////////////////
100+
__global__ void max1D(const double *Params, const float *data, float *conv_sig){
101+
102+
volatile __shared__ float sdata[Nthreads+81];
103+
float y, spkTh;
104+
int tid, tid0, bid, i, NT, nt0, nt0min;
105+
106+
NT = (int) Params[0];
107+
nt0 = (int) Params[2];
108+
nt0min = (int) Params[5];
109+
spkTh = (float) Params[6];
110+
111+
tid = threadIdx.x;
112+
bid = blockIdx.x;
113+
114+
tid0 = 0;
115+
while (tid0<NT-Nthreads-nt0+1){
116+
if (tid<nt0)
117+
sdata[tid] = data[tid0 + tid + NT*bid];
118+
sdata[tid + nt0] = data[nt0+tid0 + tid+ NT*bid];
119+
__syncthreads();
120+
121+
y = 0.0f;
122+
#pragma unroll 4
123+
for(i=0;i<2*nt0min;i++)
124+
y = max(y, sdata[tid+i]);
125+
126+
if (y>spkTh*spkTh)
127+
conv_sig[tid0 + 1*(nt0min) + tid + NT*bid] = y;
128+
129+
tid0+=Nthreads;
130+
__syncthreads();
131+
}
132+
}
133+
134+
//////////////////////////////////////////////////////////////////////////////////////////
135+
__global__ void maxChannels(const double *Params, const float *dataraw, const float *data,
136+
const int *iC, const int *iC2, const float *dist2, const int *kkmax,
137+
const float *dfilt, int *st, int *counter, float *cF){
138+
139+
int nt0, indx, tid, tid0, i, bid, NT, j,iChan, nt0min, Nrank, kfilt;
140+
int Nchan, NchanNear, NchanUp, NchanNearUp, bidy ;
141+
double Cf, d;
142+
float spkTh, d2;
143+
bool flag;
144+
145+
NT = (int) Params[0];
146+
Nchan = (int) Params[1];
147+
NchanNear = (int) Params[3];
148+
NchanUp = (int) Params[7];
149+
NchanNearUp = (int) Params[8];
150+
nt0 = (int) Params[2];
151+
nt0min = (int) Params[5];
152+
spkTh = (float) Params[6];
153+
Nrank = (int) Params[4];
154+
155+
tid = threadIdx.x;
156+
bid = blockIdx.x;
157+
bidy = blockIdx.y;
158+
159+
tid0 = tid + bid * blockDim.x;
160+
while (tid0<NT-nt0-nt0min){
161+
i = bidy;
162+
Cf = (double) data[tid0 + NT * i];
163+
flag = true;
164+
for(j=1; j<NchanNearUp; j++){
165+
if (dist2[j + NchanNearUp * i] < 100.){
166+
iChan = iC2[j+ NchanNearUp * i];
167+
if (data[tid0 + NT * iChan] > Cf){
168+
flag = false;
169+
break;
170+
}
171+
}
172+
}
173+
174+
if (flag){
175+
if (Cf>spkTh*spkTh){
176+
d = (double) dataraw[tid0+0 * (nt0min-1) + NT*i]; //
177+
if (d > Cf-1e-6){
178+
// this is a hit, atomicAdd and return spikes
179+
indx = atomicAdd(&counter[0], 1);
180+
if (indx<maxFR){
181+
st[0+4*indx] = tid0;
182+
st[1+4*indx] = i;
183+
st[2+4*indx] = sqrt(d);
184+
st[3+4*indx] = kkmax[tid0+0*(nt0min-1) + NT*i];
185+
kfilt = st[3+4*indx]%Nrank;
186+
for(j=0; j<NchanNear; j++){
187+
iChan = iC[j+ NchanNear * i];
188+
cF[j + NchanNear * indx] = dfilt[tid0+0*(nt0min-1) + NT * iChan + kfilt * Nchan*NT];
189+
}
190+
}
191+
}
192+
}
193+
}
194+
195+
tid0 += blockDim.x * gridDim.x;
196+
}
197+
}

pykilosort/datashift/datashift.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
ir.xc, ir.yc = probe.xc, probe.yc
2+
ir.ops = Bunch()
3+
4+
# The min and max of the y and x ranges of the channels
5+
ymin = min(ir.yc)
6+
ymax = max(ir.yc)
7+
xmin = min(ir.xc)
8+
xmax = max(ir.xc)
9+
10+
# Determine the average vertical spacing between channels.
11+
# Usually all the vertical spacings are the same, i.e. on Neuropixels probes.
12+
dmin = np.median(np.diff(np.unique(ir.yc)))
13+
print(f"pitch is {dmin} um\n")
14+
ir.ops.yup = np.arange(
15+
start=ymin, step=dmin / 2, stop=ymax
16+
) # centers of the upsampled y positions
17+
18+
# Determine the template spacings along the x dimension
19+
x_range = xmax - xmin
20+
npt = math.floor(
21+
x_range / 16
22+
) # this would come out as 16um for Neuropixels probes, which aligns with the geometry.
23+
ir.ops.xup = np.linspace(xmin, xmax, npt + 1) # centers of the upsampled x positions
24+
25+
spkTh = 10 # same as the usual "template amplitude", but for the generic templates
26+
27+
# Extract all the spikes across the recording that are captured by the
28+
# generic templates. Very few real spikes are missed in this way.
29+
st3 = standalone_detector(ir, spkTh)
30+
31+
# binning width across Y (um)
32+
dd = 5
33+
34+
# detected depths
35+
dep = st3[:, 2]
36+
37+
# min and max for the range of depths
38+
dmin = ymin - 1
39+
dep = dep - dmin
40+
41+
dmax = 1 + ceil(max(dep) / dd)
42+
Nbatches = ir.temp.Nbatch
43+
44+
# which batch each spike is coming from
45+
batch_id = st3[:, 5] # ceil[st3[:,1]/dt]
46+
47+
# preallocate matrix of counts with 20 bins, spaced logarithmically
48+
F = np.zeros(dmax, 20, Nbatches)
49+
for t in range(Nbatches):
50+
# find spikes in this batch
51+
ix = np.where(batch_id == t)
52+
53+
# subtract offset
54+
dep = st3[ix, 2] - dmin
55+
56+
# amplitude bin relative to the minimum possible value
57+
amp = log10(min(99, st3[ix, 3])) - log10(spkTh)
58+
59+
# normalization by maximum possible value
60+
amp = amp / (log10(100) - log10(spkTh))
61+
62+
# multiply by 20 to distribute a [0,1] variable into 20 bins
63+
# sparse is very useful here to do this binning quickly
64+
M = sparse(ceil(dep / dd), ceil(1e-5 + amp * 20), ones(numel(ix), 1), dmax, 20)
65+
66+
# the counts themselves are taken on a logarithmic scale (some neurons
67+
# fire too much!)
68+
F[:, :, t] = log2(1 + M)
69+
end
70+
71+
##
72+
# the 'midpoint' branch is for chronic recordings that have been
73+
# concatenated in the binary file
74+
# if isfield(ops, 'midpoint')
75+
# # register the first block as usual
76+
# [imin1, F1] = align_block(F(:, :, 1:ops.midpoint))
77+
# # register the second block as usual
78+
# [imin2, F2] = align_block(F(:, :, ops.midpoint+1:end))
79+
# # now register the average first block to the average second block
80+
# d0 = align_pairs(F1, F2)
81+
# # concatenate the shifts
82+
# imin = [imin1 imin2 + d0]
83+
# imin = imin - mean(imin)
84+
# ops.datashift = 1
85+
# else
86+
# # determine registration offsets
87+
# ysamp = dmin + dd * [1:dmax] - dd/2
88+
# [imin,yblk, F0] = align_block2(F, ysamp, ops.nblocks)
89+
# end
90+
91+
##
92+
if opts.get("fig", True):
93+
ax = plt.subplot()
94+
# plot the shift trace in um
95+
ax.plot(imin * dd)
96+
97+
ax = plt.subplot()
98+
# raster plot of all spikes at their original depths
99+
st_shift = st3[:, 2] # + imin(batch_id)' * dd
100+
for j in range(spkTh, 100):
101+
# for each amplitude bin, plot all the spikes of that size in the
102+
# same shade of gray
103+
ix = st3[:, 3] == j # the amplitudes are rounded to integers
104+
ax.plot(
105+
st3[ix, 1],
106+
st_shift[ix],
107+
".",
108+
"color",
109+
[max(0, 1 - j / 40) for i in range(3)],
110+
) # the marker color here has been carefully tuned
111+
plt.tight_layout()
112+
113+
# if we're creating a registered binary file for visualization in Phy
114+
if opts.get("fbinaryproc", False):
115+
with open(opts["fbinaryproc"], "w") as f:
116+
pass
117+
118+
# convert to um
119+
dshift = imin * dd
120+
# sort in case we still want to do "tracking"
121+
122+
_, ir.iorig = np.sort(np.mean(dshift, 2))
123+
124+
# sigma for the Gaussian process smoothing
125+
sig = ir.ops.sig
126+
# register the data batch by batch
127+
for ibatch in range(Nbatches):
128+
shift_batch_on_disk2(ir, ibatch, dshift[ibatch, :], yblk, sig)
129+
end
130+
fprintf("time #2.2f, Shifted up/down #d batches. \n", toc, Nbatches)
131+
132+
# keep track of dshift
133+
ir.dshift = dshift
134+
# keep track of original spikes
135+
ir.st0 = st3
136+
137+
# next, we can just run a normal spike sorter, like Kilosort1, and forget about the transformation that has happened in here

0 commit comments

Comments
 (0)