Skip to content

Commit b4a4949

Browse files
committed
Merge branch 'master' of github.com:rossant/pykilosort
* 'master' of github.com:rossant/pykilosort: Typed Configuration (#28) Comparison between MATLAB & Python Versions (#25) Add dockerfile for running tests with GPU (#22)
2 parents 5ffda8b + ccb67dc commit b4a4949

File tree

5 files changed

+95
-87
lines changed

5 files changed

+95
-87
lines changed

pykilosort/default_params.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

pykilosort/gui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from phy.gui.qt import QSlider, Qt, QLabel, QScrollArea, QVBoxLayout, QWidget
1818
from phy.gui import Actions
1919

20-
from .default_params import default_params
20+
from .params import KilosortParams
2121
from .main import run
2222

2323
logger = logging.getLogger(__name__)
@@ -168,7 +168,7 @@ def create_probe_view(self, gui):
168168
def create_params_widget(self, gui):
169169
"""Create the widget that allows to enter parameters for KS2."""
170170
widget = KeyValueWidget(gui)
171-
for name, default in default_params.items():
171+
for name, default in KilosortParams().dict().items():
172172
# HACK: None default params in KS2 are floats
173173
vtype = 'float' if default is None else None
174174
widget.add_pair(name, default, vtype=vtype)

pykilosort/main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
from pathlib import Path
44
from phylib.io.traces import get_ephys_reader
55

6-
from pprint import pprint
76
import numpy as np
7+
from pprint import pprint
8+
from pydantic import BaseModel
89

910
from .preprocess import preprocess, get_good_channels, get_whitening_matrix, get_Nbatch
1011
from .cluster import clusterSingleBatches
1112
from .learn import learnAndSolve8b
1213
from .postprocess import find_merges, splitAllClusters, set_cutoff, rezToPhy
1314
from .utils import Bunch, Context, memmap_large_array, load_probe
14-
from .default_params import default_params, set_dependent_params
15+
from .params import KilosortParams
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -61,10 +62,8 @@ def run(
6162
assert probe
6263

6364
# Get params.
64-
user_params = params or {}
65-
params = default_params.copy()
66-
set_dependent_params(params)
67-
params.update(user_params)
65+
if not isinstance(params, BaseModel):
66+
params = KilosortParams(**params or {})
6867
assert params
6968

7069
# dir path

pykilosort/params.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import typing as t
2+
from math import ceil
3+
4+
from pydantic import BaseModel, Field
5+
6+
from .utils import Bunch
7+
# TODO: design - Let's move all of this to a yaml file with sections so that its easier to read.
8+
# - We can then just parse the yaml file to generate this.
9+
10+
11+
class KilosortParams(BaseModel):
12+
fs: float = Field(30000., description="sample rate")
13+
14+
fshigh: float = Field(150., description="high pass filter frequency")
15+
fslow: t.Optional[float] = Field(None, description="low pass filter frequency")
16+
minfr_goodchannels: float = Field(0.1, description="minimum firing rate on a 'good' channel (0 to skip)")
17+
18+
Th: t.List[float] = Field([10, 4], description="""
19+
threshold on projections (like in Kilosort1, can be different for last pass like [10 4])
20+
""")
21+
ThPre: float = Field(8, description="threshold crossings for pre-clustering (in PCA projection space)")
22+
23+
lam: float = Field(10, description="""
24+
how important is the amplitude penalty (like in Kilosort1, 0 means not used,
25+
10 is average, 50 is a lot)
26+
""")
27+
28+
AUCsplit: float = Field(0.9, description="""
29+
splitting a cluster at the end requires at least this much isolation for each sub-cluster (max=1)
30+
""")
31+
32+
minFR: float = Field(1. / 50, description="""
33+
minimum spike rate (Hz), if a cluster falls below this for too long it gets removed
34+
""")
35+
36+
momentum: t.List[float] = Field([20, 400], description="""
37+
number of samples to average over (annealed from first to second value)
38+
""")
39+
40+
sigmaMask: float = Field(30, description="""
41+
spatial constant in um for computing residual variance of spike
42+
""")
43+
44+
# danger, changing these settings can lead to fatal errors
45+
# options for determining PCs
46+
spkTh: float = Field(-6, description="spike threshold in standard deviations")
47+
reorder: int = Field(1, description="whether to reorder batches for drift correction.")
48+
nskip: int = Field(5, description="how many batches to skip for determining spike PCs")
49+
nSkipCov: int = Field(25, description="compute whitening matrix from every nth batch")
50+
51+
# GPU = 1 # has to be 1, no CPU version yet, sorry
52+
# Nfilt = 1024 # max number of clusters
53+
nfilt_factor: int = Field(4, description="max number of clusters per good channel (even temporary ones)")
54+
ntbuff = Field(64, description="""
55+
samples of symmetrical buffer for whitening and spike detection
56+
57+
Must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory).
58+
""")
59+
60+
whiteningRange: int = Field(32, description="number of channels to use for whitening each channel")
61+
nSkipCov: int = Field(25, description="compute whitening matrix from every N-th batch")
62+
scaleproc: int = Field(200, description="int16 scaling of whitened data")
63+
nPCs: int = Field(3, description="how many PCs to project the spikes into")
64+
65+
nt0: int = 61
66+
nup: int = 10
67+
sig: int = 1
68+
gain: int = 1
69+
70+
templateScaling: float = 20.0
71+
72+
loc_range: t.List[int] = [5, 4]
73+
long_range: t.List[int] = [30, 6]
74+
75+
Nfilt: t.Optional[int] = None # This should be a computed property once we add the probe to the config
76+
77+
# Computed properties
78+
@property
79+
def NT(self) -> int:
80+
return 64 * 1024 + self.ntbuff
81+
82+
@property
83+
def NTbuff(self) -> int:
84+
return self.NT + 4 * self.ntbuff
85+
86+
@property
87+
def nt0min(self) -> int:
88+
return int(ceil(20 * self.nt0 / 61))

pykilosort/preprocess.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ def get_good_channels(raw_data=None, probe=None, params=None):
252252
minfr_goodchannels = params.minfr_goodchannels
253253

254254
chanMap = probe.chanMap
255-
# Nchan = probe.Nchan
256255
NchanTOT = len(chanMap)
257256

258257
ich = []

0 commit comments

Comments
 (0)