Commit ca247c0

Make tests stricter by checking for the exact error type. Also fixed bugs.

1 parent ae94ee6 · commit ca247c0

File tree

2 files changed: +34, −25 lines

elephant/schemas/function_validator.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@ def decorator(func):
     @wraps(func)
     def wrapper(*args, **kwargs):

+        print(skip_validation)
         if not skip_validation:
             # Bind args & kwargs to function parameters
             bound = sig.bind_partial(*args, **kwargs)
@@ -30,7 +31,9 @@ def wrapper(*args, **kwargs):
     return decorator

 def activate_validation():
+    global skip_validation
     skip_validation = False

 def deactivate_validation():
+    global skip_validation
     skip_validation = True
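A note on the `global` fix above: in Python, assigning to a name inside a function creates a function-local variable unless the name is declared `global`, so the previous `activate_validation()`/`deactivate_validation()` rebound a local and left the module-level `skip_validation` flag untouched. A minimal standalone sketch of the pitfall (illustrative names, not the module's code):

flag = True  # module-level flag, like skip_validation in function_validator.py

def broken_disable():
    flag = False          # rebinds a *local* "flag"; the module-level one is untouched

def fixed_disable():
    global flag           # declare intent to rebind the module-level name
    flag = False

broken_disable()
print(flag)               # True  -- the buggy version silently did nothing
fixed_disable()
print(flag)               # False -- the fixed version actually toggles the flag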

elephant/test/test_schemas.py

Lines changed: 31 additions & 25 deletions
@@ -6,7 +6,8 @@

 import elephant

-from elephant.schemas.function_validator import deactivate_validation
+from pydantic import ValidationError
+from elephant.schemas.function_validator import deactivate_validation, activate_validation

 from elephant.schemas.schema_statistics import *
 from elephant.schemas.schema_spike_train_correlation import *
@@ -48,10 +49,15 @@ def test_model_json_schema():
     so consistency is checked correctly
     """

-# Deactivate validation happening in the decorator of the elephant functions for all tests in this module to keep checking consistent behavior
-@pytest.fixture(autouse=True)
-def disable_validation_for_tests():
-    deactivate_validation()
+# Deactivates validation happening in the decorator of the elephant functions before all tests in this module, to keep the checks consistent. Activates it again after all tests in this module have run.
+
+@pytest.fixture(scope="module", autouse=True)
+def module_setup_teardown():
+    deactivate_validation()
+
+    yield
+
+    activate_validation()

 @pytest.fixture
 def make_list():
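The replacement fixture uses pytest's yield-fixture setup/teardown pattern: with `scope="module"` and `autouse=True`, the code before `yield` runs once before the first test in the module and the code after it runs once after the last, so validation is restored for any later test modules. A toy, self-contained sketch of the mechanics (hypothetical flag, not the project's code):

import pytest

validation_enabled = True  # stand-in for the skip_validation flag

@pytest.fixture(scope="module", autouse=True)
def module_setup_teardown():
    global validation_enabled
    validation_enabled = False   # setup: runs once, before the first test here
    yield                        # all tests in this module execute at this point
    validation_enabled = True    # teardown: runs once, after the last test here

def test_validation_is_off():
    assert validation_enabled is False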
@@ -119,9 +125,9 @@ def test_valid_spiketrain_input(elephant_fn, model_cls, fixture):
 ])
 def test_invalid_spiketrain(elephant_fn, model_cls, spiketrain):
     invalid = {"spiketrain": spiketrain}
-    with pytest.raises(Exception):
+    with pytest.raises(TypeError):
         model_cls(**invalid)
-    with pytest.raises(Exception):
+    with pytest.raises((TypeError, ValueError)):
         elephant_fn(**invalid)
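This is the stricter pattern the commit message refers to: `pytest.raises(Exception)` accepts almost any failure, including unrelated bugs, so pinning the exact type (or a small tuple of acceptable types) makes the test meaningful. A minimal illustration with a toy validator (hypothetical, not elephant's code):

import pytest

def check_spiketrain(obj):
    # toy validator: rejects wrong types and nonsensical values
    if not isinstance(obj, list):
        raise TypeError("expected a list of spike times")
    if any(t < 0 for t in obj):
        raise ValueError("spike times must be non-negative")
    return obj

def test_exact_type():
    with pytest.raises(TypeError):                # passes only for TypeError
        check_spiketrain("not a list")

def test_any_of_two_types():
    with pytest.raises((TypeError, ValueError)):  # passes for either type
        check_spiketrain([-1.0])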
@@ -147,9 +153,9 @@ def test_valid_pq_quantity(elephant_fn, model_cls, make_spiketrains, make_pq_single_quantity):
 ])
 def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quantity):
     invalid = {"spiketrains": make_spiketrains, "bin_size": pq_quantity}
-    with pytest.raises(Exception):
+    with pytest.raises(TypeError):
         model_cls(**invalid)
-    with pytest.raises(Exception):
+    with pytest.raises(AttributeError):
         elephant_fn(**invalid)
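Expecting `AttributeError` from the raw elephant call (versus `TypeError` from the Pydantic model) is plausible because elephant duck-types its `bin_size` argument: a non-Quantity value fails only when a quantity attribute is first accessed. A toy illustration of that failure mode (hypothetical function, not elephant's implementation):

def bin_size_in_ms(bin_size):
    # assumes a pq.Quantity-like object exposing .rescale()
    return bin_size.rescale("ms")

try:
    bin_size_in_ms(0.01)   # a plain float has no .rescale()
except AttributeError as err:
    print(err)             # "'float' object has no attribute 'rescale'"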

@@ -164,9 +170,9 @@ def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quantity):
 ], indirect=["fixture"])
 def test_invalid_spiketrains(elephant_fn, model_cls, fixture, make_pq_single_quantity):
     invalid = {"spiketrains": fixture, "sampling_period": make_pq_single_quantity}
-    with pytest.raises(Exception):
+    with pytest.raises(TypeError):
         model_cls(**invalid)
-    with pytest.raises(Exception):
+    with pytest.raises(TypeError):
         elephant_fn(**invalid)

 @pytest.mark.parametrize("output", [
@@ -190,9 +196,9 @@ def test_valid_enum(output, make_spiketrains, make_pq_single_quantity):
 ])
 def test_invalid_enum(output, make_spiketrains, make_pq_single_quantity):
     invalid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output}
-    with pytest.raises(Exception):
+    with pytest.raises(ValidationError):
         PydanticTimeHistogram(**invalid)
-    with pytest.raises(Exception):
+    with pytest.raises(ValueError):
         elephant.statistics.time_histogram(**invalid)
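One detail behind the types chosen here: pydantic's `ValidationError` is a subclass of `ValueError`, so asserting `ValidationError` on the model is the tightest check available, while the plain elephant function is only expected to raise a `ValueError`. A toy sketch (hypothetical model, not the project's schema):

from pydantic import BaseModel, ValidationError

class ToyModel(BaseModel):   # stand-in for PydanticTimeHistogram
    output: int

try:
    ToyModel(output="not-a-number")   # fails field validation
except ValidationError as err:
    # ValidationError subclasses ValueError, so pytest.raises(ValueError)
    # would also catch it; ValidationError is the stricter assertion.
    print(isinstance(err, ValueError))   # True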

@@ -204,31 +210,31 @@ def test_valid_binned_spiketrain(make_binned_spiketrain):

 def test_invalid_binned_spiketrain(make_spiketrain):
     invalid = {"binned_spiketrain": make_spiketrain}
-    with pytest.raises(Exception):
+    with pytest.raises(TypeError):
         PydanticCovariance(**invalid)
-    with pytest.raises(Exception):
+    with pytest.raises(AttributeError):
         elephant.spike_train_correlation.covariance(**invalid)

-@pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [
-    (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, "spiketrains", []),
-    (elephant.statistics.optimal_kernel_bandwidth, PydanticOptimalKernelBandwidth, "spiketimes", np.array([])),
-    (elephant.statistics.cv2, PydanticCv2, "time_intervals", np.array([])*pq.s),
+@pytest.mark.parametrize("elephant_fn,model_cls,invalid", [
+    (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, {"spiketrains": [], "sampling_period": 0.01 * pq.s}),
+    (elephant.statistics.optimal_kernel_bandwidth, PydanticOptimalKernelBandwidth, {"spiketimes": np.array([])}),
+    (elephant.statistics.cv2, PydanticCv2, {"time_intervals": np.array([])*pq.s}),
 ])
-def test_invalid_empty_input(elephant_fn, model_cls, parameter_name, empty_input):
-    invalid = {parameter_name: empty_input}
-    with pytest.raises(Exception):
+def test_invalid_empty_input(elephant_fn, model_cls, invalid):
+
+    with pytest.raises(ValueError):
         model_cls(**invalid)
-    with pytest.raises(Exception):
+    with pytest.raises((ValueError, TypeError)):
         elephant_fn(**invalid)

 @pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [
     (elephant.spike_train_correlation.covariance, PydanticCovariance, "binned_spiketrain", elephant.conversion.BinnedSpikeTrain(neo.core.SpikeTrain(np.array([])*pq.s, t_start=0*pq.s, t_stop=1*pq.s), bin_size=0.01*pq.s)),
 ])
 def test_warning_empty_input(elephant_fn, model_cls, parameter_name, empty_input):
     warning = {parameter_name: empty_input}
-    with pytest.warns(Warning):
+    with pytest.warns(UserWarning):
         model_cls(**warning)
-    with pytest.warns(Warning):
+    with pytest.warns(UserWarning):
         elephant_fn(**warning)
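The reworked `test_invalid_empty_input` parametrizes over complete kwargs dicts instead of a single `(parameter_name, empty_input)` pair, so a case can supply extra required arguments, as `instantaneous_rate` now does with `sampling_period`. A small sketch of the pattern (toy function, illustrative names):

import pytest

def firing_rate(spiketrains, sampling_period=1.0):   # toy stand-in
    if len(spiketrains) == 0:
        raise ValueError("spiketrains must not be empty")
    return len(spiketrains) / sampling_period

# Each case is a full kwargs dict, so cases may differ in which
# (and how many) arguments they pass.
@pytest.mark.parametrize("invalid", [
    {"spiketrains": []},
    {"spiketrains": [], "sampling_period": 0.01},
])
def test_empty_input_raises(invalid):
    with pytest.raises(ValueError):
        firing_rate(**invalid)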