diff --git a/.gitignore b/.gitignore index 6f6651146..f73ca14e3 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,5 @@ ignored/ # neo logs **/logs + +pytest.ini \ No newline at end of file diff --git a/elephant/schemas/__init__.py b/elephant/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/elephant/schemas/field_serializer.py b/elephant/schemas/field_serializer.py new file mode 100644 index 000000000..d7d2587bd --- /dev/null +++ b/elephant/schemas/field_serializer.py @@ -0,0 +1,9 @@ +import quantities as pq + +def serialize_quantity(value: pq.Quantity) -> dict: + if value is None: + return None + return { + "value": value.magnitude, + "unit": value.dimensionality + } \ No newline at end of file diff --git a/elephant/schemas/field_validator.py b/elephant/schemas/field_validator.py new file mode 100644 index 000000000..6a61323c5 --- /dev/null +++ b/elephant/schemas/field_validator.py @@ -0,0 +1,247 @@ +import numpy as np +import quantities as pq +import neo +import elephant +from enum import Enum +from typing import Any +import warnings + +def get_length(obj) -> int: + """ + Return the length (number of elements) of various supported datatypes: + - list + - numpy.ndarray + - pq.Quantity + - neo.SpikeTrain + + Returns + ------- + int + The number of elements or spikes in the object. + + Raises + ------ + TypeError + If the object type is not supported. + """ + if obj is None: + raise ValueError("Cannot get length of None") + + if isinstance(obj, elephant.trials.Trials): + return len(obj.trials) + elif isinstance(obj, elephant.conversion.BinnedSpikeTrain): + return obj.n_bins + elif isinstance(obj, neo.SpikeTrain): + return len(obj) + elif isinstance(obj, pq.Quantity): + return obj.size + elif isinstance(obj, np.ndarray): + return obj.size + elif isinstance(obj, (list,tuple)): + return len(obj) + elif isinstance(obj, neo.core.spiketrainlist.SpikeTrainList): + return len(obj) + + + + else: + raise TypeError( + f"Unsupported type for length computation: {type(obj).__name__}" + ) + +def is_sorted(obj) -> bool: + if obj is None: + raise ValueError("Cannot check sortedness of None") + + if isinstance(obj, (list, np.ndarray, pq.Quantity)): + arr = np.asarray(obj) + return np.all(arr[:-1] <= arr[1:]) + elif isinstance(obj, neo.SpikeTrain): + arr = obj.magnitude # Get the underlying numpy array of spike times + return np.all(arr[:-1] <= arr[1:]) + return False + +def is_matrix(obj) -> bool: + if obj is None: + raise ValueError("Cannot check matrix of None") + if isinstance(obj, (list, np.ndarray, pq.Quantity)): + arr = np.asarray(obj) + return arr.ndim >= 2 + elif isinstance(obj, neo.SpikeTrain): + arr = obj.magnitude # Get the underlying numpy array of spike times + return arr.ndim >= 2 + return False + +def validate_covariance_matrix_rank_deficient(obj, info): + """ + Check if the covariance matrix of the given object is rank deficient. + Should work for elephant.trials.Trials, list of neo.core.spiketrainlist.SpikeTrainList or list of list of neo.core.SpikeTrain. 
+ """ + return obj + +def validate_type( + value, + info, + allowed_types: tuple, + allow_none: bool, +): + """Generic type validation helper.""" + if value is None: + if allow_none: + return value + raise ValueError(f"{info.field_name} cannot be None") + + if not isinstance(value, allowed_types): + raise TypeError(f"{info.field_name} must be one of {allowed_types}, not {type(value).__name__}") + return value + +def validate_length( + value, + info: str, + min_length: int, + warning: bool +): + if min_length>0: + if get_length(value) < min_length: + if warning: + warnings.warn(f"{info.field_name} has less than {min_length} elements", UserWarning) + else: + raise ValueError(f"{info.field_name} must contain at least {min_length} elements") + return value + +def validate_type_length(value, info, allowed_types: tuple, allow_none: bool, min_length: int, warning: bool = False): + validate_type(value, info, allowed_types, allow_none) + if value is not None: + validate_length(value, info, min_length, warning) + return value + +def validate_array_content(value, info, allowed_types: tuple, allow_none: bool, min_length: int, allowed_content_types: tuple, min_length_content: int = 0): + validate_type_length(value, info, allowed_types, allow_none, min_length) + hasContentLength = False + for i, item in enumerate(value): + if not isinstance(item, allowed_content_types): + raise TypeError(f"Element {i} in {info.field_name} must be {allowed_content_types}, not {type(item).__name__}") + if min_length_content > 0 and get_length(item) >= min_length_content: + hasContentLength = True + if(min_length_content > 0 and not hasContentLength): + raise ValueError(f"{info.field_name} must contain at least one element with at least {min_length_content} elements") + + return value + +# ---- Specialized validation helpers ---- + +def validate_spiketrain(value, info, allowed_types=(list, neo.SpikeTrain, pq.Quantity, np.ndarray), allow_none = False, min_length = 1, check_sorted = False): + validate_type_length(value, info, allowed_types, allow_none, min_length) + if(check_sorted): + if value is not None and not is_sorted(value): + warnings.warn(f"{info.field_name} is not sorted", UserWarning) + if(isinstance(value, neo.SpikeTrain)): + if value.t_start is not None and value.t_stop is not None: + if value.t_start > value.t_stop: + raise ValueError(f"{info.field_name} has t_start > t_stop") + return value + +def validate_spiketrains(value, info, allowed_types = (list,), allow_none = False, min_length = 1, allowed_content_types = (list, neo.SpikeTrain, pq.Quantity, np.ndarray), min_length_content = 0): + validate_array_content(value, info, allowed_types, allow_none, min_length, allowed_content_types, min_length_content) + return value + +def validate_spiketrains_matrix(value, info, allowed_types = (elephant.trials.Trials, list[neo.core.spiketrainlist.SpikeTrainList], list[list[neo.core.SpikeTrain]]), allow_none = False, min_length = 1, check_rank_deficient = False): + if isinstance(value, list): + validate_spiketrains(value, info, allowed_content_types=(neo.core.spiketrainlist.SpikeTrainList,list[neo.core.SpikeTrain],)) + else: + validate_type(value, info, (elephant.trials.Trials,), allow_none=False) + if check_rank_deficient: + return validate_covariance_matrix_rank_deficient(value, info) + return value + +def validate_time(value, info, allowed_types=(float, pq.Quantity) ,allow_none=True): + if(isinstance(value, np.ndarray) and value.size==1): + value = value.item() + + validate_type(value, info, allowed_types, allow_none) + 
return value + +def validate_quantity(value, info, allow_none=False): + validate_type(value, info, (pq.Quantity,), allow_none) + return value + +def validate_time_intervals(value, info, allowed_types = (list, pq.Quantity, np.ndarray), allow_none = False, min_length=0, check_matrix = False): + validate_type_length(value, info, allowed_types, allow_none, min_length) + if check_matrix: + if value is not None and is_matrix(value): + raise ValueError(f"{info.field_name} is not allowed to be a matrix") + return value + +def validate_array(value, info, allowed_types=(list, np.ndarray, tuple) , allow_none=False, min_length=1, allowed_content_types = None, min_length_content = 0): + if allowed_content_types is None: + validate_type_length(value, info, allowed_types, allow_none, min_length) + else: + validate_array_content(value, info, allowed_types, allow_none, min_length, allowed_content_types, min_length_content) + return value + +def validate_binned_spiketrain(value, info, allowed_types=(elephant.conversion.BinnedSpikeTrain,), allow_none=False, min_length=1): + validate_type_length(value, info, allowed_types, allow_none, min_length, warning=True) + if value is not None and isinstance(value, elephant.conversion.BinnedSpikeTrain): + spmat = value.sparse_matrix + + # Check for empty spike trains + n_spikes_per_row = spmat.sum(axis=1) + if n_spikes_per_row.min() == 0: + warnings.warn( + f'Detected empty spike trains (rows) in the {info.field_name}.', UserWarning) + return value + +def validate_dict_enum_types(value : dict[Enum, Any], info, typeDictionary: dict[Enum, type]): + for key, val in value.items(): + if not isinstance(val, typeDictionary[key]): + raise TypeError(f"Value for key {key} in {info.field_name} must be of type {typeDictionary[key].__name__}, not {type(val).__name__}") + return value + +def validate_key_in_tuple(value : str, info, t: tuple): + if value not in t: + raise ValueError(f"{info}:{value} is not in the options {t}") + return value + + +# ---- Model validation helpers ---- + +def model_validate_spiketrains_same_t_start_stop(spiketrains, t_start, t_stop, name: str = "spiketrains", warning: bool = False): + if(t_start is None or t_stop is None): + first = True + for i, item in enumerate(spiketrains): + if first: + t_start = item.t_start + t_stop = item.t_stop + first = False + else: + if t_start is None and item.t_start != t_start: + if warning: + warnings.warn(f"{name} has different t_start values among its elements", UserWarning) + else: + raise ValueError(f"{name} has different t_start values among its elements") + if t_stop is None and item.t_stop != t_stop: + if warning: + warnings.warn(f"{name} has different t_stop values among its elements", UserWarning) + else: + raise ValueError(f"{name} has different t_stop values among its elements") + else: + if t_start>t_stop: + raise ValueError(f"{name} has t_start > t_stop") + +def model_validate_two_spiketrains_same_t_start_stop(spiketrain_i, spiketrain_j): + if spiketrain_i.t_start != spiketrain_j.t_start: + raise ValueError("spiketrain_i and spiketrain_j need to have the same t_start") + if spiketrain_i.t_stop != spiketrain_j.t_stop: + raise ValueError("spiketrain_i and spiketrain_j need to have the same t_stop") + +def model_validate_time_intervals_with_nan(time_intervals , with_nan, name: str = "time_intervals"): + if get_length(time_intervals)<2: + if(with_nan): + warnings.warn(f"{name} has less than two entries so a np.Nan will be generated", UserWarning) + else: + raise ValueError(f"{name} has less than two entries") + 
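Note: a minimal sketch (not part of this patch) of how the field-level helpers above are meant to be wired into a Pydantic schema; `ExampleModel` and `check_spiketrain` are hypothetical names used only for illustration, assuming Pydantic v2 passes a ValidationInfo object whose `field_name` the helpers use when building error messages.

    import numpy as np
    from typing import Any
    from pydantic import BaseModel, Field, field_validator

    import elephant.schemas.field_validator as fv

    class ExampleModel(BaseModel):
        spiketrain: Any = Field(..., description="Spike times")

        @field_validator("spiketrain")
        @classmethod
        def check_spiketrain(cls, v, info):
            # info.field_name is what the helpers use for error messages
            return fv.validate_type_length(
                v, info, allowed_types=(list, np.ndarray),
                allow_none=False, min_length=1)

    ExampleModel(spiketrain=[0.01, 0.02, 0.05])   # passes
    # ExampleModel(spiketrain=[])   # ValidationError: needs at least 1 element
    # ExampleModel(spiketrain=5)    # TypeError: not a list or ndarray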
+def model_validate_binned_spiketrain_fast(binned_spiketrain, fast, name: str = "binned_spiketrain"): + if(fast and np.max(binned_spiketrain.shape) > np.iinfo(np.int32).max): + raise MemoryError(f"{name} is too large for fast=True option") + \ No newline at end of file diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py new file mode 100644 index 000000000..e388339a4 --- /dev/null +++ b/elephant/schemas/function_validator.py @@ -0,0 +1,38 @@ +from functools import wraps +from inspect import signature +from pydantic import BaseModel + +skip_validation = False + +def validate_with(model_class: type[BaseModel]): + """ + A decorator that validates the inputs of a function using a Pydantic model. + Works for both positional and keyword arguments. + """ + def decorator(func): + sig = signature(func) + + @wraps(func) + def wrapper(*args, **kwargs): + + if not skip_validation: + # Bind args & kwargs to function parameters + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + data = bound.arguments + + # Validate using Pydantic + model_class(**data) + + # Call function + return func(*args, **kwargs) + return wrapper + return decorator + +def activate_validation(): + global skip_validation + skip_validation = False + +def deactivate_validation(): + global skip_validation + skip_validation = True \ No newline at end of file diff --git a/elephant/schemas/schema_spike_train_correlation.py b/elephant/schemas/schema_spike_train_correlation.py new file mode 100644 index 000000000..a2153cd0e --- /dev/null +++ b/elephant/schemas/schema_spike_train_correlation.py @@ -0,0 +1,153 @@ +import quantities as pq +import numpy as np +from typing import ( + Any, + Union, + Optional +) +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + field_serializer +) +import neo +from enum import Enum + +import elephant.schemas.field_validator as fv +import elephant.schemas.field_serializer as fs + +class PydanticCovariance(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.covariance function + with additional type checking and json_schema by PyDantic. + """ + + binned_spiketrain: Any = Field(..., description="Binned spike train") + binary: Optional[bool] = Field(False, description="Use binary binned vectors") + fast: Optional[bool] = Field(True, description="Use faster implementation") + + @field_validator("binned_spiketrain") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @model_validator(mode="after") + def validate_model(self): + fv.model_validate_binned_spiketrain_fast(self.binned_spiketrain, self.fast) + return self + + +class PydanticCorrelationCoefficient(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.correlation_coefficient function + with additional type checking and json_schema by PyDantic. 
+ """ + + binned_spiketrain: Any = Field(..., description="Binned spike train") + binary: Optional[bool] = Field(False, description="Use binary binned vectors") + fast: Optional[bool] = Field(True, description="Use faster implementation") + + @field_validator("binned_spiketrain") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @model_validator(mode="after") + def validate_model(self): + fv.model_validate_binned_spiketrain_fast(self.binned_spiketrain, self.fast) + return self + + +class PydanticCrossCorrelationHistogram(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.cross_correlation_histogram function + with additional type checking and json_schema by PyDantic. + """ + + class WindowOptions(Enum): + full = "full" + valid = "valid" + + class MethodOptions(Enum): + speed = "speed" + memory = "memory" + + binned_spiketrain_i: Any = Field(..., description="Binned spike train i") + binned_spiketrain_j: Any = Field(..., description="Binned spike train j") + window: Optional[Union[WindowOptions, list[int]]] = Field(WindowOptions.full, description="Window") + border_correction: Optional[bool] = Field(False, description="Correct border effect") + binary: Optional[bool] = Field(False, description="Count spike falling same bin as one") + kernel: Optional[Any] = Field(None, description="array containing a smoothing kernel") + method: Optional[MethodOptions] = Field(MethodOptions.speed, description="Method of calculating") + cross_correlation_coefficient: Optional[bool] = Field(False, description="Normalize CCH") + + @field_validator("binned_spiketrain_i", "binned_spiketrain_j") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @field_validator("kernel") + @classmethod + def validate_kernel(cls, v, info): + return fv.validate_array(v, info, allowed_types=(np.ndarray,), allow_none=True) + + +class PydanticSpikeTimeTilingCoefficient(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.spike_time_tiling_coefficient function + with additional type checking and json_schema by PyDantic. + """ + + spiketrain_i: Any = Field(..., description="Spike train Object i") + spiketrain_j: Any = Field(..., description="Spike train Object j (same T_start and same t_stop)") + dt: Any = Field(default_factory=lambda: 0.005 * pq.s, description="Synchronicity window") + + @field_serializer("dt", mode='plain') + def serialize_quantity(self, value: pq.Quantity): + return fs.serialize_quantity(value) + + @field_validator("spiketrain_i", "spiketrain_j") + @classmethod + def validate_spiketrain(cls, v, info): + # require specifically neo.core.SpikeTrain for this validator + return fv.validate_spiketrain(v, info, allowed_types=(neo.core.SpikeTrain,)) + + @field_validator("dt") + @classmethod + def validate_dt(cls, v, info): + return fv.validate_quantity(v, info) + + @model_validator(mode="after") + def check_correctTypeCombination(self): + fv.model_validate_two_spiketrains_same_t_start_stop(self.spiketrain_i, self.spiketrain_j) + return self + + +class PydanticSpikeTrainTimescale(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.spike_train_timescale function + with additional type checking and json_schema by PyDantic. 
+ """ + + binned_spiketrain: Any = Field(..., description="Binned spike train") + max_tau: Any = Field(..., description="Maximal integration time") + + @field_validator("binned_spiketrain") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @field_validator("max_tau") + @classmethod + def validate_max_tau(cls, v, info): + return fv.validate_quantity(v, info) + + @model_validator(mode="after") + def check_correctTypeCombination(self): + if self.max_tau % self.binned_spiketrain.bin_size > 0.00001: + raise ValueError("max_tau has to be a multiple of binned_spiketrain.bin_size") + return self + + diff --git a/elephant/schemas/schema_spike_train_synchrony.py b/elephant/schemas/schema_spike_train_synchrony.py new file mode 100644 index 000000000..4c967f874 --- /dev/null +++ b/elephant/schemas/schema_spike_train_synchrony.py @@ -0,0 +1,64 @@ +import quantities as pq +import numpy as np +from typing import ( + Any, + Optional +) +from pydantic import ( + BaseModel, + Field, + field_validator, + field_serializer +) +import neo +from enum import Enum +from elephant.schemas.schema_statistics import PydanticComplexityInit + +import elephant.schemas.field_validator as fv +import elephant.schemas.field_serializer as fs + + +class PydanticSpikeContrast(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_synchrony.spike_contrast function + with additional type checking and json_schema by PyDantic. + """ + + spiketrains: list = Field(..., description="List of Spiketrains") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + min_bin: Optional[Any] = Field(default_factory=lambda: 10. * pq.ms, description="Min value for bin_min") + bin_shrink_factortime: Optional[float] = Field(0.9, description="Shrink bin size multiplier", ge=0., le=1.) 
+ return_trace: Optional[bool] = Field(False, description="Return history of spike-contrast synchrony") + + @field_serializer("min_bin", mode='plain') + def serialize_quantity(self, value: pq.Quantity): + return fs.serialize_quantity(value) + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + return fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,), min_length=2, min_length_content=2) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_time(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @field_validator("min_bin") + @classmethod + def validate_min_bin(cls, v, info): + return fv.validate_quantity(v, info) + + +class PydanticSynchrotoolInit(PydanticComplexityInit): + pass + +class PydanticSynchrotoolDeleteSynchrofacts(BaseModel): + class ModeOptions(Enum): + delete = "delete" + extract = "extract" + + threshold: int = Field(..., gt=1, description="Threshold for deletion of spikes") + in_place: Optional[bool] = Field(False, description="Make modification in place") + mode: Optional[ModeOptions] = Field(ModeOptions.delete, description="Inversion of mask for deletion") \ No newline at end of file diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py new file mode 100644 index 000000000..9d79644da --- /dev/null +++ b/elephant/schemas/schema_statistics.py @@ -0,0 +1,347 @@ +import quantities as pq +import numpy as np +from typing import ( + Any, + Union, + Optional +) +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + field_serializer +) +import neo +from enum import Enum +import elephant + +from elephant.kernels import Kernel +import elephant.schemas.field_validator as fv +import elephant.schemas.field_serializer as fs + +import warnings + +class PydanticMeanFiringRate(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.mean_firing_rate function + with additional type checking and json_schema by PyDantic. + """ + spiketrain: Any = Field(None, description="SpikeTrain Object") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + axis: Optional[int] = Field(None, description="Axis of calculation") + + @field_validator("spiketrain") + @classmethod + def validate_spiketrain(cls, v, info): + return fv.validate_spiketrain(v, info, allow_none=True) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_time(cls, v, info): + return fv.validate_time(v, info) + + @model_validator(mode="after") + def validate_model(self): + if isinstance(self.spiketrain, (np.ndarray, list)): + if isinstance(self.t_start, pq.Quantity) or isinstance(self.t_stop, pq.Quantity): + raise TypeError("spiketrain is a np.ndarray or list but t_start or t_stop is pq.Quantity") + elif not (isinstance(self.t_start, pq.Quantity) and isinstance(self.t_stop, pq.Quantity)): + raise TypeError("spiketrain is a neo.SpikeTrain or pq.Quantity but t_start or t_stop is not pq.Quantity") + return self + +class PydanticInstantaneousRate(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.instantaneous_rate function + with additional type checking and json_schema by PyDantic. 
+ """ + + class KernelOptions(Enum): + auto = "auto" + + spiketrains: Any = Field(..., description="Input spike train(s)") + sampling_period: Any = Field(..., gt=0, description="Time stamp resolution of spike times") + kernel: Union[KernelOptions, Any] = Field(KernelOptions.auto, description="Kernel for convolution") + cutoff: Optional[float] = Field(5.0, ge=0, description="cutoff of probability distribution") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + trim: Optional[bool] = Field(False, description="Only return region of convolved signal") + center_kernel: Optional[bool] = Field(True, description="Center the kernel on spike") + border_correction: Optional[bool] = Field(False, description="Apply border correction") + pool_trials: Optional[bool] = Field(False, description="Calc firing rates averaged over trials when spiketrains is Trials object") + pool_spike_trains: Optional[bool] = Field(False, description="Calc firing rates averaged over spiketrains") + + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + if(isinstance(v, (list, neo.core.spiketrainlist.SpikeTrainList))): + return fv.validate_spiketrains(v, info, allowed_types=(list, neo.core.spiketrainlist.SpikeTrainList), allowed_content_types=(neo.SpikeTrain,)) + if(isinstance(v, neo.SpikeTrain)): + return fv.validate_spiketrain(v, info, allowed_types=(neo.SpikeTrain,)) + return fv.validate_spiketrains_matrix(v, info) + + @field_validator("sampling_period") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + + @field_validator("kernel") + @classmethod + def validate_kernel(cls, v, info): + if v == cls.KernelOptions.auto.value: + return v + return fv.validate_type(v, info, allowed_types=(Kernel), allow_none=False) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_time(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @model_validator(mode="after") + def validate_model(self): + if(isinstance(self.kernel, Kernel) and self.cutoff < self.kernel.min_cutoff): + warnings.warn(f"cutoff {self.cutoff} is smaller than the minimum cutoff {self.kernel.min_cutoff} of the kernel", UserWarning) + if isinstance(self.spiketrains, list): + fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) + return self + +class PydanticTimeHistogram(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.time_histogram function + with additional type checking and json_schema by PyDantic. 
+ """ + + class OutputOptions(Enum): + counts = "counts" + mean = "mean" + rate = "rate" + + spiketrains: list = Field(..., description="List of Spiketrains") + bin_size: Any = Field(..., description="Width histogram's time bins") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + output: Optional[OutputOptions] = Field(OutputOptions.counts, description="Normalization") + binary: Optional[bool] = Field(False, description="To binary") + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + return fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,)) + + @field_validator("bin_size") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_quantity_none(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @model_validator(mode="after") + def validate_model(self): + fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) + return self + +class PydanticOptimalKernelBandwidth(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.optimal_kernel_bandwidth function + with additional type checking and json_schema by PyDantic. + """ + + spiketimes: Any = Field(..., description="Sequence of spike times(ASC)") + times: Optional[Any] = Field(None, description="Time at which kernel bandwidth") + bandwidth: Optional[Any] = Field(None, description="Vector of kernal bandwidth") + bootstrap: Optional[bool] = Field(False, description="Use Bootstrap") + + @field_validator("spiketimes") + @classmethod + def validate_ndarray(cls, v, info): + return fv.validate_array(v, info, allowed_types=(np.ndarray,)) + + @field_validator("times", "bandwidth") + @classmethod + def validate_ndarray_none(cls, v, info): + return fv.validate_array(v, info, allowed_types=(np.ndarray,), allow_none=True) + +class PydanticIsi(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.isi function + with additional type checking and json_schema by PyDantic. + """ + spiketrain: Any = Field(..., description="SpikeTrain Object (sorted)") + axis: Optional[int] = Field(-1, description="Difference Axis") + + @field_validator("spiketrain") + @classmethod + def validate_spiketrain_sorted(cls, v, info): + return fv.validate_spiketrain(v, info, check_sorted=True) + +class PydanticCv(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.cv function + with additional type checking and json_schema by PyDantic. + """ + class NanPolicyOptions(Enum): + propagate = "propagate" + omit = "omit" + _raise = "raise" + + args: Any = Field(..., description="Input array") + axis: Union[int, None] = Field(0, description="Compute statistic axis") + nan_policy: NanPolicyOptions = Field(NanPolicyOptions.propagate, description="How handle input NaNs") + ddof: Optional[int] = Field(0, ge=0, description="Delta Degrees Of Freedom") + keepdims: Optional[bool] = Field(False, description="leave reduced axes in one-dimensional result") + + @field_validator("args") + @classmethod + def validate_array(cls, v, info): + return fv.validate_array(v, info) + +class PydanticCv2(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.cv2 function + with additional type checking and json_schema by PyDantic. 
+ """ + + time_intervals: Any = Field(..., description="Vector of time intervals") + with_nan: Optional[bool] = Field(False, description="Do not Raise warning on short spike train") + + @field_validator("time_intervals") + @classmethod + def validate_time_intervals(cls, v, info): + return fv.validate_time_intervals(v, info, check_matrix=True) + + @model_validator(mode="after") + def validate_model(self): + fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) + return self + +class PydanticLv(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.lv function + with additional type checking and json_schema by PyDantic. + """ + + time_intervals: Any = Field(..., description="Vector of time intervals") + with_nan: Optional[bool] = Field(False, description="Do not Raise warning on short spike train") + + @field_validator("time_intervals") + @classmethod + def validate_time_intervals(cls, v, info): + return fv.validate_time_intervals(v, info, check_matrix=True) + + @model_validator(mode="after") + def validate_model(self): + fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) + return self + +class PydanticLvr(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.lvr function + with additional type checking and json_schema by PyDantic. + """ + + time_intervals: Any = Field(..., description="Vector of time intervals (default units: ms)") + R: Any = Field(default_factory=lambda: 5. * pq.ms, ge=0, description="Refractoriness constant (default quantity: ms)") + with_nan: Optional[bool] = Field(False, description="Do not Raise warning on short spike train") + + @field_serializer("R", mode='plain') + def serialize_quantity(self, v): + return fs.serialize_quantity(v) + + @field_validator("time_intervals") + @classmethod + def validate_time_intervals(cls, v, info): + return fv.validate_time_intervals(v, info, check_matrix=True) + + @field_validator("R") + @classmethod + def validate_R(cls, v, info): + fv.validate_type(v, info, (pq.Quantity, int, float), allow_none=False) + if(not isinstance(v, pq.Quantity)): + warnings.warn("R does not have any units so milliseconds are assumed", UserWarning) + return v + + @model_validator(mode="after") + def validate_model(self): + fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) + return self + +class PydanticFanofactor(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.fanofactor function + with additional type checking and json_schema by PyDantic. + """ + spiketrains: list = Field(..., description="List of Spiketrains") + warn_tolerance: Any = Field(default_factory=lambda: 0.1 * pq.ms, ge=0, description="Warn tolerence of variations") + + @field_serializer("warn_tolerance", mode='plain') + def serialize_quantity(self, v): + return fs.serialize_quantity(v) + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + return fv.validate_spiketrains(v, info, min_length=0) + + @field_validator("warn_tolerance") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + +class PydanticComplexityPdf(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.complexity_pdf function + with additional type checking and json_schema by PyDantic. 
+ """ + spiketrains: list = Field(..., description="List of Spiketrains") + bin_size: Any = Field(..., description="Width histogram's time bins") + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,)) + fv.model_validate_spiketrains_same_t_start_stop(v, None, None) + return v + + @field_validator("bin_size") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + +class PydanticComplexityInit(BaseModel): + spiketrains: list = Field(..., description="List of neo.SpikeTrain objects with common t_start/t_stop") + sampling_rate: Optional[Any] = Field(None, description="Sampling rate (1/time)") + bin_size: Optional[Any] = Field(None, description="Width of histogram bins") + binary: Optional[bool] = Field(True, description="If True count neurons, else total spikes") + spread: Optional[int] = Field(0, ge=0, description="Number of bins for synchronous spikes (>=0)") + tolerance: Optional[float] = Field(1e-8, description="Tolerance for rounding errors") + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,)) + fv.model_validate_spiketrains_same_t_start_stop(v, None, None) + return v + + @field_validator("bin_size") + @classmethod + def validate_bin_size(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @field_validator("sampling_rate") + @classmethod + def validate_sampling_rate(cls, v, info): + fv.validate_quantity(v, info, allow_none=True) + if v is None: + warnings.warn("no sampling rate is supplied. This may lead to rounding errors when using the epoch to slice spike trains", UserWarning) + return v + + @model_validator(mode="after") + def check_rate_or_bin(self): + if self.sampling_rate is None and self.bin_size is None: + raise ValueError("Either sampling_rate or bin_size must be set") + return self \ No newline at end of file diff --git a/elephant/spike_train_correlation.py b/elephant/spike_train_correlation.py index 1d7cf0656..6bb03641f 100644 --- a/elephant/spike_train_correlation.py +++ b/elephant/spike_train_correlation.py @@ -25,6 +25,9 @@ from scipy import integrate from elephant.utils import check_neo_consistency +from elephant.schemas.function_validator import validate_with +from elephant.schemas.schema_spike_train_correlation import * + __all__ = [ "covariance", @@ -276,6 +279,7 @@ def kernel_smoothing(self, cross_corr_array, kernel): return np.convolve(cross_corr_array, kernel, mode='same') +@validate_with(PydanticCovariance) def covariance(binned_spiketrain, binary=False, fast=True): r""" Calculate the NxN matrix of pairwise covariances between all combinations @@ -376,6 +380,7 @@ def covariance(binned_spiketrain, binary=False, fast=True): binned_spiketrain, corrcoef_norm=False) +@validate_with(PydanticCorrelationCoefficient) def correlation_coefficient(binned_spiketrain, binary=False, fast=True): r""" Calculate the NxN matrix of pairwise Pearson's correlation coefficients @@ -549,6 +554,7 @@ def _covariance_sparse(binned_spiketrain, corrcoef_norm): return res +@validate_with(PydanticCrossCorrelationHistogram) def cross_correlation_histogram( binned_spiketrain_i, binned_spiketrain_j, window='full', border_correction=False, binary=False, kernel=None, method='speed', @@ -818,6 +824,7 @@ def cross_correlation_histogram( cch = cross_correlation_histogram 
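Note: a minimal sketch (not part of this patch) of what @validate_with does at call time for the functions decorated in this module: the call's positional and keyword arguments are bound to the function signature, defaults are applied, and the resulting dict is handed to the schema before the wrapped function runs. `toy_divide` and `PydanticToyDivide` are hypothetical names used only for illustration.

    from pydantic import BaseModel, Field
    from elephant.schemas.function_validator import validate_with

    class PydanticToyDivide(BaseModel):
        numerator: float = Field(..., description="Dividend")
        denominator: float = Field(1.0, gt=0, description="Divisor, must be positive")

    @validate_with(PydanticToyDivide)
    def toy_divide(numerator, denominator=1.0):
        return numerator / denominator

    toy_divide(6.0, denominator=2.0)    # arguments validated, then the real call proceeds
    # toy_divide(6.0, denominator=0.0)  # pydantic.ValidationError raised before toy_divide runs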
+@validate_with(PydanticSpikeTimeTilingCoefficient) def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain, spiketrain_j: neo.core.SpikeTrain, dt: pq.Quantity = 0.005 * pq.s) -> float: @@ -992,6 +999,7 @@ def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float: sttc = spike_time_tiling_coefficient +@validate_with(PydanticSpikeTrainTimescale) def spike_train_timescale(binned_spiketrain, max_tau): r""" Calculates the auto-correlation time of a binned spike train; uses the diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 946a24ae2..0e27cb4f7 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -29,6 +29,9 @@ from elephant.statistics import Complexity from elephant.utils import is_time_quantity, check_same_units +from elephant.schemas.function_validator import validate_with +from elephant.schemas.schema_spike_train_synchrony import * + SpikeContrastTrace = namedtuple("SpikeContrastTrace", ( "contrast", "active_spiketrains", "synchrony", "bin_size")) @@ -69,6 +72,7 @@ def _binning_half_overlap(spiketrain, edges): return histogram +@validate_with(PydanticSpikeContrast) def spike_contrast(spiketrains, t_start=None, t_stop=None, min_bin=10 * pq.ms, bin_shrink_factor=0.9, return_trace=False): @@ -261,6 +265,7 @@ class Synchrotool(Complexity): """ + @validate_with(PydanticSynchrotoolInit) def __init__(self, spiketrains, sampling_rate, bin_size=None, @@ -277,6 +282,7 @@ def __init__(self, spiketrains, spread=spread, tolerance=tolerance) + @validate_with(PydanticSynchrotoolDeleteSynchrofacts) def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): """ Delete or extract synchronous spiking events. diff --git a/elephant/statistics.py b/elephant/statistics.py index 45d9cd283..b6fd9deb5 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -83,6 +83,9 @@ from elephant.utils import deprecated_alias, check_neo_consistency, \ is_time_quantity, round_binning_errors +from elephant.schemas.function_validator import validate_with +from elephant.schemas.schema_statistics import * + # do not import unicode_literals # (quantities rescale does not work with unicodes) @@ -102,9 +105,12 @@ "optimal_kernel_bandwidth" ] -cv = scipy.stats.variation +@validate_with(PydanticCv) +def cv(*args, **kwargs): + return scipy.stats.variation(*args, **kwargs) +@validate_with(PydanticIsi) def isi(spiketrain, axis=-1): """ Return an array containing the inter-spike intervals of the spike train. @@ -155,7 +161,7 @@ def isi(spiketrain, axis=-1): return intervals - +@validate_with(PydanticMeanFiringRate) def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): """ Return the firing rate of the spike train. @@ -270,6 +276,7 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): return rates +@validate_with(PydanticFanofactor) def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): r""" Evaluates the empirical Fano factor F of the spike counts of @@ -373,6 +380,7 @@ def __variation_check(v, with_nan): return None +@validate_with(PydanticCv2) @deprecated_alias(v='time_intervals') def cv2(time_intervals, with_nan=False): r""" @@ -441,6 +449,7 @@ def cv2(time_intervals, with_nan=False): return 2. * np.mean(np.abs(cv_i)) +@validate_with(PydanticLv) @deprecated_alias(v='time_intervals') def lv(time_intervals, with_nan=False): r""" @@ -508,6 +517,7 @@ def lv(time_intervals, with_nan=False): return 3. 
* np.mean(np.power(cv_i, 2)) +@validate_with(PydanticLvr) def lvr(time_intervals, R=5*pq.ms, with_nan=False): r""" Calculate the measure of revised local variation LvR for a sequence of time @@ -600,6 +610,7 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False): return lvr +@validate_with(PydanticInstantaneousRate) @deprecated_alias(spiketrain='spiketrains') def instantaneous_rate(spiketrains, sampling_period, kernel='auto', cutoff=5.0, t_start=None, t_stop=None, trim=False, @@ -1061,6 +1072,7 @@ def optimal_kernel(st): return rate +@validate_with(PydanticTimeHistogram) @deprecated_alias(binsize='bin_size') def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, output='counts', binary=False): @@ -1204,6 +1216,7 @@ def _rate() -> pq.Quantity: normalization=output) +@validate_with(PydanticComplexityPdf) @deprecated_alias(binsize='bin_size') def complexity_pdf(spiketrains, bin_size): """ @@ -1418,6 +1431,7 @@ class Complexity(object): """ + @validate_with(PydanticComplexityInit) def __init__(self, spiketrains, sampling_rate=None, bin_size=None, @@ -1716,6 +1730,7 @@ def cost_function(x, N, w, dt): return C, yh +@validate_with(PydanticOptimalKernelBandwidth) @deprecated_alias(tin='times', w='bandwidth') def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, bootstrap=False): diff --git a/elephant/test/test_schemas.py b/elephant/test/test_schemas.py new file mode 100644 index 000000000..3db6b8162 --- /dev/null +++ b/elephant/test/test_schemas.py @@ -0,0 +1,245 @@ + +import pytest +import quantities as pq +import neo +import numpy as np + +import elephant + +from pydantic import ValidationError +from elephant.schemas.function_validator import deactivate_validation, activate_validation + +from elephant.schemas.schema_statistics import * +from elephant.schemas.schema_spike_train_correlation import * +from elephant.schemas.schema_spike_train_synchrony import * + + +def test_model_json_schema(): + # Just test that json_schema generation runs without error for all models + model_classes = [ + PydanticCovariance, + PydanticCorrelationCoefficient, + PydanticCrossCorrelationHistogram, + PydanticSpikeTimeTilingCoefficient, + PydanticSpikeTrainTimescale, + PydanticMeanFiringRate, + PydanticInstantaneousRate, + PydanticTimeHistogram, + PydanticOptimalKernelBandwidth, + PydanticIsi, + PydanticCv, + PydanticCv2, + PydanticLv, + PydanticLvr, + PydanticFanofactor, + PydanticComplexityPdf, + PydanticSpikeContrast, + PydanticComplexityInit, + PydanticSynchrotoolInit, + PydanticSynchrotoolDeleteSynchrofacts + ] + for cls in model_classes: + schema = cls.model_json_schema() + assert isinstance(schema, dict) + + +""" +Checking for consistent behavior between Elephant functions and Pydantic models. +Tests bypass validate_with decorator if it is already implemented for that function +so consistency is checked correctly +""" + +# Deactivate validation happening in the decorator of the elephant functions before all tests in this module to keep checking consistent behavior. Activates it again after all tests in this module have run. 
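Note: a minimal sketch (not part of this patch) of the same toggle packaged as a context manager for use outside pytest (e.g. in benchmarking scripts); `no_validation` is a hypothetical helper. The module-scoped fixture below achieves the same effect for this test module.

    from contextlib import contextmanager
    from elephant.schemas.function_validator import (
        activate_validation, deactivate_validation)

    @contextmanager
    def no_validation():
        deactivate_validation()     # sets function_validator.skip_validation = True
        try:
            yield
        finally:
            activate_validation()   # restores validation even if the body raises

    with no_validation():
        pass  # elephant calls made here skip the Pydantic pre-check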
+ +@pytest.fixture(scope="module", autouse=True) +def module_setup_teardown(): + deactivate_validation() + + yield + + activate_validation() + +@pytest.fixture +def make_list(): + return [0.01, 0.02, 0.05] + +@pytest.fixture +def make_ndarray(make_list): + return np.array(make_list) + +@pytest.fixture +def make_pq_single_quantity(): + return 0.05 * pq.s + +@pytest.fixture +def make_pq_multiple_quantity(make_ndarray): + return make_ndarray * pq.s + +@pytest.fixture +def make_spiketrain(make_pq_multiple_quantity): + return neo.core.SpikeTrain(make_pq_multiple_quantity, t_start=0 * pq.s, t_stop=0.1 * pq.s) + +@pytest.fixture +def make_spiketrains(make_spiketrain): + return [make_spiketrain, make_spiketrain] + +@pytest.fixture +def make_binned_spiketrain(make_spiketrain): + return elephant.conversion.BinnedSpikeTrain(make_spiketrain, bin_size=0.01 * pq.s) + +@pytest.fixture +def make_analog_signal(): + n2 = 300 + n0 = 100000 - n2 + return neo.AnalogSignal(np.array([10] * n2 + [0] * n0).reshape(n0 + n2, 1) * pq.dimensionless, sampling_period=1 * pq.s) + +@pytest.fixture +def fixture(request): + return request.getfixturevalue(request.param) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.mean_firing_rate, PydanticMeanFiringRate), + (elephant.statistics.isi, PydanticIsi), +]) +@pytest.mark.parametrize("fixture", [ + "make_list", + "make_spiketrain", + "make_ndarray", + "make_pq_multiple_quantity", +], indirect=["fixture"]) +def test_valid_spiketrain_input(elephant_fn, model_cls, fixture): + valid = {"spiketrain": fixture} + assert(isinstance(model_cls(**valid), model_cls)) + # just check it runs without error + elephant_fn(**valid) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.mean_firing_rate, PydanticMeanFiringRate), + (elephant.statistics.isi, PydanticIsi), +]) +@pytest.mark.parametrize("spiketrain", [ + 5, + "hello", +]) +def test_invalid_spiketrain(elephant_fn, model_cls, spiketrain): + invalid = {"spiketrain": spiketrain} + with pytest.raises(TypeError): + model_cls(**invalid) + with pytest.raises((TypeError, ValueError)): + elephant_fn(**invalid) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.time_histogram, PydanticTimeHistogram), + (elephant.statistics.complexity_pdf, PydanticComplexityPdf), +]) +def test_valid_pq_quantity(elephant_fn, model_cls, make_spiketrains, make_pq_single_quantity): + valid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity} + assert(isinstance(model_cls(**valid), model_cls)) + # just check it runs without error + elephant_fn(**valid) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.time_histogram, PydanticTimeHistogram), + (elephant.statistics.complexity_pdf, PydanticComplexityPdf), +]) +@pytest.mark.parametrize("pq_quantity", [ + 5, + "hello", + [0.01, 0.02] +]) +def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quantity): + invalid = {"spiketrains": make_spiketrains, "bin_size": pq_quantity} + with pytest.raises(TypeError): + model_cls(**invalid) + with pytest.raises(AttributeError): + elephant_fn(**invalid) + + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate), +]) +@pytest.mark.parametrize("fixture", [ + "make_list", + "make_ndarray", + "make_pq_multiple_quantity", +], indirect=["fixture"]) +def test_invalid_spiketrains(elephant_fn, model_cls, fixture, make_pq_single_quantity): + invalid = {"spiketrains": fixture, 
"sampling_period": make_pq_single_quantity} + with pytest.raises(TypeError): + model_cls(**invalid) + with pytest.raises(TypeError): + elephant_fn(**invalid) + +@pytest.mark.parametrize("output", [ + "counts", + "mean", + "rate", +]) +def test_valid_enum(output, make_spiketrains, make_pq_single_quantity): + valid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} + assert(isinstance(PydanticTimeHistogram(**valid), PydanticTimeHistogram)) + # just check it runs without error + elephant.statistics.time_histogram(**valid) + +@pytest.mark.parametrize("output", [ + "countsfagre", + 5, + "Counts", + "counts ", + " counts", + "counts\n" +]) +def test_invalid_enum(output, make_spiketrains, make_pq_single_quantity): + invalid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} + with pytest.raises(ValidationError): + PydanticTimeHistogram(**invalid) + with pytest.raises(ValueError): + elephant.statistics.time_histogram(**invalid) + + +def test_valid_binned_spiketrain(make_binned_spiketrain): + valid = {"binned_spiketrain": make_binned_spiketrain} + assert(isinstance(PydanticCovariance(**valid), PydanticCovariance)) + # just check it runs without error + elephant.spike_train_correlation.covariance(**valid) + +def test_invalid_binned_spiketrain(make_spiketrain): + invalid = {"binned_spiketrain": make_spiketrain} + with pytest.raises(TypeError): + PydanticCovariance(**invalid) + with pytest.raises(AttributeError): + elephant.spike_train_correlation.covariance(**invalid) + +@pytest.mark.parametrize("elephant_fn,model_cls,invalid", [ + (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, {"spiketrains": [], "sampling_period": 0.01 * pq.s}), + (elephant.statistics.optimal_kernel_bandwidth, PydanticOptimalKernelBandwidth, {"spiketimes": np.array([])}), + (elephant.statistics.cv2, PydanticCv2, {"time_intervals": np.array([])*pq.s}), +]) +def test_invalid_empty_input(elephant_fn, model_cls, invalid): + + with pytest.raises(ValueError): + model_cls(**invalid) + with pytest.raises((ValueError,TypeError)): + elephant_fn(**invalid) + +@pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ + (elephant.spike_train_correlation.covariance, PydanticCovariance, "binned_spiketrain", elephant.conversion.BinnedSpikeTrain(neo.core.SpikeTrain(np.array([])*pq.s, t_start=0*pq.s, t_stop=1*pq.s), bin_size=0.01*pq.s)), +]) +def test_warning_empty_input(elephant_fn, model_cls, parameter_name, empty_input): + warning = {parameter_name: empty_input} + with pytest.warns(UserWarning): + model_cls(**warning) + with pytest.warns(UserWarning): + elephant_fn(**warning) + + +def test_valid_Complexity(make_spiketrains, make_pq_single_quantity): + valid = { "spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity } + assert(isinstance(PydanticComplexityInit(**valid), PydanticComplexityInit)) + # just check it runs without error + elephant.statistics.Complexity(**valid) \ No newline at end of file diff --git a/elephant/test/test_spike_train_correlation.py b/elephant/test/test_spike_train_correlation.py index 90de65ea1..75682332b 100644 --- a/elephant/test/test_spike_train_correlation.py +++ b/elephant/test/test_spike_train_correlation.py @@ -913,7 +913,7 @@ def test_timescale_errors(self): # Tau max with no units tau_max = 1 - self.assertRaises(ValueError, + self.assertRaises((ValueError, TypeError), sc.spike_train_timescale, spikes_bin, tau_max) # Tau max that is not a multiple of the binsize diff --git 
a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..a94c4aa68 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -383,7 +383,7 @@ def test_lv_with_list(self): def test_lv_raise_error(self): seq = self.test_seq self.assertRaises(ValueError, statistics.lv, []) - self.assertRaises(ValueError, statistics.lv, 1) + self.assertRaises((ValueError, TypeError), statistics.lv, 1) self.assertRaises(ValueError, statistics.lv, np.array([seq, seq])) def test_2short_spike_train(self): @@ -430,7 +430,7 @@ def test_lvr_with_list(self): def test_lvr_raise_error(self): seq = self.test_seq self.assertRaises(ValueError, statistics.lvr, []) - self.assertRaises(ValueError, statistics.lvr, 1) + self.assertRaises((ValueError, TypeError), statistics.lvr, 1) self.assertRaises(ValueError, statistics.lvr, np.array([seq, seq])) self.assertRaises(ValueError, statistics.lvr, seq, -1 * pq.ms) @@ -478,7 +478,7 @@ def test_cv2_with_list(self): def test_cv2_raise_error(self): seq = self.test_seq self.assertRaises(ValueError, statistics.cv2, []) - self.assertRaises(ValueError, statistics.cv2, 1) + self.assertRaises((ValueError, TypeError), statistics.cv2, 1) self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq])) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b3b9d6f98..3021ae5f9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,3 +4,4 @@ quantities>=0.14.1 scipy>=1.10.0 six>=1.10.0 tqdm +pydantic>=2.0.0 \ No newline at end of file
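Note: a minimal usage sketch (not part of this patch) of the new field serializer, e.g. for PydanticLvr.R. The returned dict holds the raw magnitude (an ndarray) and a Dimensionality object, which model_dump_json() may not serialize directly; the float()/str() casts shown here are an assumption about how a JSON-friendly form could be obtained, not something the patch itself does.

    import quantities as pq
    from elephant.schemas.field_serializer import serialize_quantity

    d = serialize_quantity(5. * pq.ms)
    print(d["value"], d["unit"])              # 5.0 ms  (ndarray magnitude, Dimensionality object)
    print(float(d["value"]), str(d["unit"]))  # 5.0 ms  (plain JSON-friendly values)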