|
| 1 | +import typing as t |
| 2 | +from math import ceil |
| 3 | + |
| 4 | +from pydantic import BaseModel, Field |
| 5 | + |
| 6 | +from .utils import Bunch |
| 7 | +# TODO: design - Let's move all of this to a yaml file with sections so that its easier to read. |
| 8 | +# - We can then just parse the yaml file to generate this. |
| 9 | + |
| 10 | + |
| 11 | +class KilosortParams(BaseModel): |
| 12 | + fs: float = Field(30000., description="sample rate") |
| 13 | + |
| 14 | + fshigh: float = Field(150., description="high pass filter frequency") |
| 15 | + fslow: t.Optional[float] = Field(None, description="low pass filter frequency") |
| 16 | + minfr_goodchannels: float = Field(0.1, description="minimum firing rate on a 'good' channel (0 to skip)") |
| 17 | + |
| 18 | + Th: t.List[float] = Field([10, 4], description=""" |
| 19 | + threshold on projections (like in Kilosort1, can be different for last pass like [10 4]) |
| 20 | + """) |
| 21 | + ThPre: float = Field(8, description="threshold crossings for pre-clustering (in PCA projection space)") |
| 22 | + |
| 23 | + lam: float = Field(10, description=""" |
| 24 | + how important is the amplitude penalty (like in Kilosort1, 0 means not used, |
| 25 | + 10 is average, 50 is a lot) |
| 26 | + """) |
| 27 | + |
| 28 | + AUCsplit: float = Field(0.9, description=""" |
| 29 | + splitting a cluster at the end requires at least this much isolation for each sub-cluster (max=1) |
| 30 | + """) |
| 31 | + |
| 32 | + minFR: float = Field(1. / 50, description=""" |
| 33 | + minimum spike rate (Hz), if a cluster falls below this for too long it gets removed |
| 34 | + """) |
| 35 | + |
| 36 | + momentum: t.List[float] = Field([20, 400], description=""" |
| 37 | + number of samples to average over (annealed from first to second value) |
| 38 | + """) |
| 39 | + |
| 40 | + sigmaMask: float = Field(30, description=""" |
| 41 | + spatial constant in um for computing residual variance of spike |
| 42 | + """) |
| 43 | + |
| 44 | + # danger, changing these settings can lead to fatal errors |
| 45 | + # options for determining PCs |
| 46 | + spkTh: float = Field(-6, description="spike threshold in standard deviations") |
| 47 | + reorder: int = Field(1, description="whether to reorder batches for drift correction.") |
| 48 | + nskip: int = Field(5, description="how many batches to skip for determining spike PCs") |
| 49 | + nSkipCov: int = Field(25, description="compute whitening matrix from every nth batch") |
| 50 | + |
| 51 | + # GPU = 1 # has to be 1, no CPU version yet, sorry |
| 52 | + # Nfilt = 1024 # max number of clusters |
| 53 | + nfilt_factor: int = Field(4, description="max number of clusters per good channel (even temporary ones)") |
| 54 | + ntbuff = Field(64, description=""" |
| 55 | + samples of symmetrical buffer for whitening and spike detection |
| 56 | + |
| 57 | + Must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). |
| 58 | + """) |
| 59 | + |
| 60 | + whiteningRange: int = Field(32, description="number of channels to use for whitening each channel") |
| 61 | + nSkipCov: int = Field(25, description="compute whitening matrix from every N-th batch") |
| 62 | + scaleproc: int = Field(200, description="int16 scaling of whitened data") |
| 63 | + nPCs: int = Field(3, description="how many PCs to project the spikes into") |
| 64 | + |
| 65 | + nt0: int = 61 |
| 66 | + nup: int = 10 |
| 67 | + sig: int = 1 |
| 68 | + gain: int = 1 |
| 69 | + |
| 70 | + templateScaling: float = 20.0 |
| 71 | + |
| 72 | + loc_range: t.List[int] = [5, 4] |
| 73 | + long_range: t.List[int] = [30, 6] |
| 74 | + |
| 75 | + Nfilt: t.Optional[int] = None # This should be a computed property once we add the probe to the config |
| 76 | + |
| 77 | + # Computed properties |
| 78 | + @property |
| 79 | + def NT(self) -> int: |
| 80 | + return 64 * 1024 + self.ntbuff |
| 81 | + |
| 82 | + @property |
| 83 | + def NTbuff(self) -> int: |
| 84 | + return self.NT + 4 * self.ntbuff |
| 85 | + |
| 86 | + @property |
| 87 | + def nt0min(self) -> int: |
| 88 | + return int(ceil(20 * self.nt0 / 61)) |
0 commit comments