Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f70c9cd
Added schemas Folder with __init__.py
JanNolten Nov 7, 2025
e007679
Added field_validator.py to group repeated validation
JanNolten Nov 7, 2025
ccac5cc
Added field_serializer.py to group repeated serialization
JanNolten Nov 7, 2025
11c2a35
added function_validator.py to validate elephant functions
JanNolten Nov 7, 2025
46ff38b
Added Pydantic Models for statistics
JanNolten Nov 7, 2025
5fb18dc
Added Pydantic Models for spike_train_correlation
JanNolten Nov 7, 2025
d5324bf
Added Pydantic Models for spike_train_synchrony
JanNolten Nov 7, 2025
c99b6c0
Original arguments are passed into the function
JanNolten Nov 7, 2025
cb533cc
Added pytest.ini to .gitignore
JanNolten Nov 7, 2025
71b7dc0
Added tests and option to skip validation
JanNolten Nov 7, 2025
7f4f5ef
Transferring bug fixes
JanNolten Nov 7, 2025
b282cc3
Transfering Bug fixes
JanNolten Nov 7, 2025
9c3402a
Transfering Bug fixes
JanNolten Nov 7, 2025
ac00866
Implemented validation for statistics
JanNolten Nov 7, 2025
a95f9f8
Implemented validation for spike_train_correlation
JanNolten Nov 7, 2025
69745c2
Implemented validation for spike_train_synchrony
JanNolten Nov 7, 2025
460e9cd
Allowed some ValueErrors to also be TypeErrors
JanNolten Nov 7, 2025
bf9b837
Merge branch 'NeuralEnsemble:master' into feature/partial
JanNolten Nov 7, 2025
fac02e1
Removed ; at end of lines
JanNolten Nov 10, 2025
4fdd64a
Added Pydantic to requirements
JanNolten Nov 10, 2025
b5f0ff7
Merge branch 'NeuralEnsemble:master' into feature/partial
JanNolten Nov 11, 2025
7d933a4
Removed Self from typing, because it only works in python>=3.11.0
JanNolten Nov 11, 2025
932f1d4
Added ability to disable validation globally
JanNolten Nov 11, 2025
1f58b12
Allow t_start to be negative because it should be able to be used tha…
JanNolten Nov 11, 2025
6f2b7d3
Allowed all t_start and t_stop to be negative, because they could be …
JanNolten Nov 18, 2025
c46228d
Removed the option to skip validation with the extra kwargs not_valid…
JanNolten Nov 24, 2025
ae94ee6
Simplified test to make them more understandable
JanNolten Nov 24, 2025
ca247c0
Make test stricter by checking for the exact Error Type. Also Fixed Bugs
JanNolten Nov 24, 2025
91df787
Forgot to remove a print statement
JanNolten Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,5 @@ ignored/

# neo logs
**/logs

pytest.ini
Empty file added elephant/schemas/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions elephant/schemas/field_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import quantities as pq

def serialize_quantity(value: "pq.Quantity") -> "dict | None":
    """Serialize a quantities.Quantity into a plain dict.

    Parameters
    ----------
    value : pq.Quantity or None
        The quantity to serialize.

    Returns
    -------
    dict or None
        ``{"value": <magnitude>, "unit": <unit string>}``, or ``None`` when
        *value* is ``None`` (the original annotation claimed ``dict`` only).
    """
    if value is None:
        return None
    return {
        "value": value.magnitude,
        # str() turns the Dimensionality object into a serialization-friendly
        # unit string (e.g. "ms") instead of a library-specific object.
        "unit": str(value.dimensionality)
    }
247 changes: 247 additions & 0 deletions elephant/schemas/field_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import numpy as np
import quantities as pq
import neo
import elephant
from enum import Enum
from typing import Any
import warnings

def get_length(obj) -> int:
    """
    Return the length (number of elements) of various supported datatypes:
    - elephant.trials.Trials (number of trials)
    - elephant.conversion.BinnedSpikeTrain (number of bins)
    - neo.SpikeTrain / neo SpikeTrainList
    - pq.Quantity / numpy.ndarray (total size)
    - list / tuple

    Returns
    -------
    int
        The number of elements or spikes in the object.

    Raises
    ------
    ValueError
        If the object is None.
    TypeError
        If the object type is not supported.
    """
    if obj is None:
        raise ValueError("Cannot get length of None")

    # (type(s), length extractor) pairs, probed in the same order as the
    # original isinstance chain — order matters because of subclassing
    # (e.g. neo.SpikeTrain is itself a pq.Quantity).
    dispatch = (
        (elephant.trials.Trials, lambda o: len(o.trials)),
        (elephant.conversion.BinnedSpikeTrain, lambda o: o.n_bins),
        (neo.SpikeTrain, len),
        (pq.Quantity, lambda o: o.size),
        (np.ndarray, lambda o: o.size),
        ((list, tuple), len),
        (neo.core.spiketrainlist.SpikeTrainList, len),
    )
    for candidate_types, extract in dispatch:
        if isinstance(obj, candidate_types):
            return extract(obj)

    raise TypeError(
        f"Unsupported type for length computation: {type(obj).__name__}"
    )

def is_sorted(obj) -> bool:
    """Return True when the entries of *obj* are in non-decreasing order.

    Accepts list/ndarray/Quantity (viewed as an array) and neo.SpikeTrain
    (compared on its spike-time magnitudes). Unsupported types yield False;
    None raises ValueError.
    """
    if obj is None:
        raise ValueError("Cannot check sortedness of None")

    if isinstance(obj, (list, np.ndarray, pq.Quantity)):
        seq = np.asarray(obj)
    elif isinstance(obj, neo.SpikeTrain):
        # Underlying numpy array of spike times
        seq = obj.magnitude
    else:
        return False
    return np.all(seq[:-1] <= seq[1:])

def is_matrix(obj) -> bool:
    """Return True when *obj* has two or more dimensions.

    Unsupported types yield False; None raises ValueError.
    """
    if obj is None:
        raise ValueError("Cannot check matrix of None")

    if isinstance(obj, (list, np.ndarray, pq.Quantity)):
        data = np.asarray(obj)
    elif isinstance(obj, neo.SpikeTrain):
        # Underlying numpy array of spike times
        data = obj.magnitude
    else:
        return False
    return data.ndim >= 2

def validate_covariance_matrix_rank_deficient(obj, info):
    """
    Check if the covariance matrix of the given object is rank deficient.
    Should work for elephant.trials.Trials, list of neo.core.spiketrainlist.SpikeTrainList or list of list of neo.core.SpikeTrain.

    NOTE(review): this is currently a no-op placeholder — no rank-deficiency
    check is implemented; the input is returned unchanged.
    """
    return obj

def validate_type(
    value,
    info,
    allowed_types: tuple,
    allow_none: bool,
):
    """Ensure *value* is an instance of one of *allowed_types*.

    ``None`` is accepted only when *allow_none* is set (otherwise a
    ValueError is raised); any other type mismatch raises a TypeError.
    Returns *value* unchanged on success.
    """
    if value is None:
        if not allow_none:
            raise ValueError(f"{info.field_name} cannot be None")
        return value

    if isinstance(value, allowed_types):
        return value
    raise TypeError(
        f"{info.field_name} must be one of {allowed_types}, "
        f"not {type(value).__name__}"
    )

def validate_length(
    value,
    info,
    min_length: int,
    warning: bool
):
    """Check that *value* contains at least *min_length* elements.

    The length is computed via :func:`get_length`. On failure either a
    UserWarning is emitted (``warning=True``) or a ValueError is raised.
    A *min_length* of 0 disables the check entirely. Returns *value*.

    Note: the previous ``info: str`` annotation was wrong — *info* is a
    Pydantic ``ValidationInfo``-like object; ``info.field_name`` is read.
    """
    if min_length > 0 and get_length(value) < min_length:
        if warning:
            warnings.warn(
                f"{info.field_name} has less than {min_length} elements",
                UserWarning)
        else:
            raise ValueError(
                f"{info.field_name} must contain at least "
                f"{min_length} elements")
    return value

def validate_type_length(value, info, allowed_types: tuple, allow_none: bool,
                         min_length: int, warning: bool = False):
    """Run the type check, then — for non-None values — the length check."""
    validate_type(value, info, allowed_types, allow_none)
    if value is None:
        return value
    return validate_length(value, info, min_length, warning)

def validate_array_content(value, info, allowed_types: tuple,
                           allow_none: bool, min_length: int,
                           allowed_content_types: tuple,
                           min_length_content: int = 0):
    """Validate the container itself, then the type of every element.

    When *min_length_content* is positive, at least one element must hold
    that many entries (measured via get_length), otherwise a ValueError is
    raised. Returns *value* unchanged on success.
    """
    validate_type_length(value, info, allowed_types, allow_none, min_length)
    # Flips to True as soon as any element reaches min_length_content.
    content_long_enough = False
    for index, element in enumerate(value):
        if not isinstance(element, allowed_content_types):
            raise TypeError(f"Element {index} in {info.field_name} must be "
                            f"{allowed_content_types}, "
                            f"not {type(element).__name__}")
        if min_length_content > 0 \
                and get_length(element) >= min_length_content:
            content_long_enough = True
    if min_length_content > 0 and not content_long_enough:
        raise ValueError(f"{info.field_name} must contain at least one "
                         f"element with at least "
                         f"{min_length_content} elements")

    return value

# ---- Specialized validation helpers ----

def validate_spiketrain(value, info,
                        allowed_types=(list, neo.SpikeTrain, pq.Quantity,
                                       np.ndarray),
                        allow_none=False, min_length=1, check_sorted=False):
    """Validate a single spike train.

    Checks type and minimum length, optionally warns when spike times are
    not sorted, and for neo.SpikeTrain rejects t_start > t_stop.
    """
    validate_type_length(value, info, allowed_types, allow_none, min_length)
    if check_sorted and value is not None and not is_sorted(value):
        warnings.warn(f"{info.field_name} is not sorted", UserWarning)
    if isinstance(value, neo.SpikeTrain):
        start, stop = value.t_start, value.t_stop
        if start is not None and stop is not None and start > stop:
            raise ValueError(f"{info.field_name} has t_start > t_stop")
    return value

def validate_spiketrains(value, info, allowed_types=(list,), allow_none=False,
                         min_length=1,
                         allowed_content_types=(list, neo.SpikeTrain,
                                                pq.Quantity, np.ndarray),
                         min_length_content=0):
    """Validate a container of spike trains and the type of each entry."""
    return validate_array_content(value, info, allowed_types, allow_none,
                                  min_length, allowed_content_types,
                                  min_length_content)

def validate_spiketrains_matrix(value, info,
                                allowed_types=(elephant.trials.Trials, list),
                                allow_none=False, min_length=1,
                                check_rank_deficient=False):
    """Validate trial-structured spike-train input.

    A list input is validated element-wise (each element must be a
    SpikeTrainList or a plain list); anything else must be an
    elephant.trials.Trials instance.

    Bug fix: the previous content types included the parameterized generic
    ``list[neo.core.SpikeTrain]`` — ``isinstance()`` raises TypeError for
    parameterized generics (PEP 585), so the check crashed for every list
    input. Plain ``list`` is used instead; element types of the inner lists
    are not inspected here. The *allowed_types* parameter is kept for
    interface compatibility but is not consulted by the body (as before).
    """
    if isinstance(value, list):
        validate_spiketrains(
            value, info,
            allowed_content_types=(neo.core.spiketrainlist.SpikeTrainList,
                                   list))
    else:
        validate_type(value, info, (elephant.trials.Trials,),
                      allow_none=False)
    if check_rank_deficient:
        return validate_covariance_matrix_rank_deficient(value, info)
    return value

def validate_time(value, info, allowed_types=(float, pq.Quantity),
                  allow_none=True):
    """Validate a scalar time value.

    A size-1 numpy array is first unwrapped to its Python scalar, then the
    usual type check is applied.
    """
    if isinstance(value, np.ndarray) and value.size == 1:
        value = value.item()

    return validate_type(value, info, allowed_types, allow_none)

def validate_quantity(value, info, allow_none=False):
    """Ensure *value* is a pq.Quantity (or None when allowed)."""
    return validate_type(value, info, (pq.Quantity,), allow_none)

def validate_time_intervals(value, info,
                            allowed_types=(list, pq.Quantity, np.ndarray),
                            allow_none=False, min_length=0,
                            check_matrix=False):
    """Validate a container of time intervals.

    With *check_matrix* set, multi-dimensional input is rejected.
    """
    validate_type_length(value, info, allowed_types, allow_none, min_length)
    if check_matrix and value is not None and is_matrix(value):
        raise ValueError(f"{info.field_name} is not allowed to be a matrix")
    return value

def validate_array(value, info, allowed_types=(list, np.ndarray, tuple),
                   allow_none=False, min_length=1,
                   allowed_content_types=None, min_length_content=0):
    """Validate a generic array-like value.

    Element-type (and element-length) checks run only when
    *allowed_content_types* is provided.
    """
    if allowed_content_types is not None:
        return validate_array_content(value, info, allowed_types, allow_none,
                                      min_length, allowed_content_types,
                                      min_length_content)
    return validate_type_length(value, info, allowed_types, allow_none,
                                min_length)

def validate_binned_spiketrain(value, info,
                               allowed_types=(
                                   elephant.conversion.BinnedSpikeTrain,),
                               allow_none=False, min_length=1):
    """Validate a BinnedSpikeTrain.

    Short input only warns (warning=True is passed to the length check);
    rows containing no spikes at all also trigger a UserWarning.
    """
    validate_type_length(value, info, allowed_types, allow_none, min_length,
                         warning=True)
    # isinstance is False for None, so no separate None guard is needed.
    if isinstance(value, elephant.conversion.BinnedSpikeTrain):
        sparse = value.sparse_matrix
        spikes_per_row = sparse.sum(axis=1)
        if spikes_per_row.min() == 0:
            warnings.warn(
                f'Detected empty spike trains (rows) in the {info.field_name}.', UserWarning)
    return value

def validate_dict_enum_types(value: dict[Enum, Any], info,
                             typeDictionary: dict[Enum, type]):
    """Check that every value in *value* matches the type registered for
    its key in *typeDictionary*. Returns *value* unchanged on success."""
    for key in value:
        entry = value[key]
        expected = typeDictionary[key]
        if not isinstance(entry, expected):
            raise TypeError(f"Value for key {key} in {info.field_name} must "
                            f"be of type {expected.__name__}, "
                            f"not {type(entry).__name__}")
    return value

def validate_key_in_tuple(value: str, info, t: tuple):
    """Ensure *value* is one of the options listed in *t*.

    NOTE(review): unlike the other helpers, *info* is formatted directly
    into the message here (not ``info.field_name``).
    """
    if value in t:
        return value
    raise ValueError(f"{info}:{value} is not in the options {t}")


# ---- Model validation helpers ----

def model_validate_spiketrains_same_t_start_stop(spiketrains, t_start, t_stop,
                                                 name: str = "spiketrains",
                                                 warning: bool = False):
    """Check time-bound consistency of a collection of spike trains.

    If *t_start* / *t_stop* were not supplied (None), every spike train must
    share the same t_start / t_stop as the first one; a mismatch either
    warns (``warning=True``) or raises a ValueError. If both bounds were
    supplied, only ``t_start <= t_stop`` is required.

    Bug fix: the original stored the first train's bounds in the same
    variables used to decide which checks to run, so ``t_start is None``
    became False after the first iteration and the per-element comparisons
    were never executed. The "was it None?" decision is now captured before
    the loop.
    """
    if t_start is not None and t_stop is not None:
        if t_start > t_stop:
            raise ValueError(f"{name} has t_start > t_stop")
        return

    check_start = t_start is None
    check_stop = t_stop is None
    reference_start = reference_stop = None
    for i, item in enumerate(spiketrains):
        if i == 0:
            # First element defines the reference bounds.
            reference_start = item.t_start
            reference_stop = item.t_stop
            continue
        if check_start and item.t_start != reference_start:
            if warning:
                warnings.warn(f"{name} has different t_start values among its elements", UserWarning)
            else:
                raise ValueError(f"{name} has different t_start values among its elements")
        if check_stop and item.t_stop != reference_stop:
            if warning:
                warnings.warn(f"{name} has different t_stop values among its elements", UserWarning)
            else:
                raise ValueError(f"{name} has different t_stop values among its elements")

def model_validate_two_spiketrains_same_t_start_stop(spiketrain_i,
                                                     spiketrain_j):
    """Require both spike trains to share identical t_start and t_stop."""
    for attr in ("t_start", "t_stop"):
        if getattr(spiketrain_i, attr) != getattr(spiketrain_j, attr):
            raise ValueError(f"spiketrain_i and spiketrain_j need to "
                             f"have the same {attr}")

def model_validate_time_intervals_with_nan(time_intervals, with_nan,
                                           name: str = "time_intervals"):
    """Require at least two time-interval entries.

    With *with_nan* set a short input only warns (a NaN result will be
    produced downstream); otherwise a ValueError is raised.
    """
    if get_length(time_intervals) >= 2:
        return
    if with_nan:
        warnings.warn(f"{name} has less than two entries so a np.Nan will be generated", UserWarning)
    else:
        raise ValueError(f"{name} has less than two entries")

def model_validate_binned_spiketrain_fast(binned_spiketrain, fast,
                                          name: str = "binned_spiketrain"):
    """Guard the fast=True code path against 32-bit index overflow."""
    if not fast:
        return
    if np.max(binned_spiketrain.shape) > np.iinfo(np.int32).max:
        raise MemoryError(f"{name} is too large for fast=True option")

32 changes: 32 additions & 0 deletions elephant/schemas/function_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from functools import wraps
from inspect import signature
from pydantic import BaseModel

def validate_with(model_class: type[BaseModel]):
    """
    A decorator that validates the inputs of a function using a Pydantic
    model. Works for both positional and keyword arguments.
    """
    def decorator(func):
        func_signature = signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # "not_validate" is an escape hatch: it is stripped from the
            # kwargs and, when truthy, validation is skipped entirely.
            skip_validation = kwargs.pop("not_validate", False)
            if not skip_validation:
                # Map the call's args/kwargs onto parameter names, fill in
                # defaults, then let the Pydantic model validate the mapping.
                bound_call = func_signature.bind_partial(*args, **kwargs)
                bound_call.apply_defaults()
                model_class(**bound_call.arguments)
            return func(*args, **kwargs)

        # Marker so other code can detect already-decorated functions.
        wrapper._is_validate_with = True
        return wrapper
    return decorator
Loading