Skip to content

Commit 88ca4dc

Browse files
committed
refactor test fixtures
1 parent a506d40 commit 88ca4dc

File tree

9 files changed

+162
-44
lines changed

9 files changed

+162
-44
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ repos:
3030
rev: 25.1.0 # matching versions in pyproject.toml and github actions
3131
hooks:
3232
- id: black
33-
args: ["--check", "-v", "src", "tests", "--diff"] # --required-version is conflicting with pre-commit
33+
args: ["-v", "src", "tests", "--diff"] # --required-version is conflicting with pre-commit
3434
- repo: https://github.com/PyCQA/flake8
3535
rev: 7.3.0
3636
hooks:

tests/conftest.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,57 @@ def pytest_configure(config):
4545
pass
4646

4747

48+
@pytest.fixture
49+
def clean_autopopulate(experiment, trial, ephys):
50+
"""
51+
Explicit cleanup fixture for autopopulate tests.
52+
53+
Cleans experiment/trial/ephys tables after test completes.
54+
Tests must explicitly request this fixture to get cleanup.
55+
"""
56+
yield
57+
# Cleanup after test - delete in reverse dependency order
58+
ephys.delete()
59+
trial.delete()
60+
experiment.delete()
61+
62+
63+
@pytest.fixture
64+
def clean_jobs(schema_any):
65+
"""
66+
Explicit cleanup fixture for jobs tests.
67+
68+
Cleans jobs table before test runs.
69+
Tests must explicitly request this fixture to get cleanup.
70+
"""
71+
try:
72+
schema_any.jobs.delete()
73+
except DataJointError:
74+
pass
75+
yield
76+
77+
78+
@pytest.fixture
79+
def clean_test_tables(test, test_extra, test_no_extra):
80+
"""
81+
Explicit cleanup fixture for relation tests using test tables.
82+
83+
Ensures test table has lookup data and restores clean state after test.
84+
Tests must explicitly request this fixture to get cleanup.
85+
"""
86+
# Ensure lookup data exists before test
87+
if not test:
88+
test.insert(test.contents, skip_duplicates=True)
89+
90+
yield
91+
92+
# Restore original state after test
93+
test.delete()
94+
test.insert(test.contents, skip_duplicates=True)
95+
test_extra.delete()
96+
test_no_extra.delete()
97+
98+
4899
# Global container registry for cleanup
49100
_active_containers = set()
50101
_docker_client = None
@@ -547,7 +598,7 @@ def mock_cache(tmpdir_factory):
547598
dj.config["cache"] = og_cache
548599

549600

550-
@pytest.fixture
601+
@pytest.fixture(scope="module")
551602
def schema_any(connection_test, prefix):
552603
schema_any = dj.Schema(
553604
prefix + "_test1", schema.LOCALS_ANY, connection=connection_test
@@ -603,6 +654,63 @@ def schema_any(connection_test, prefix):
603654
schema_any.drop()
604655

605656

657+
@pytest.fixture
658+
def schema_any_fresh(connection_test, prefix):
659+
"""Function-scoped schema_any for tests that need fresh schema state."""
660+
schema_any = dj.Schema(
661+
prefix + "_test1_fresh", schema.LOCALS_ANY, connection=connection_test
662+
)
663+
assert schema.LOCALS_ANY, "LOCALS_ANY is empty"
664+
try:
665+
schema_any.jobs.delete()
666+
except DataJointError:
667+
pass
668+
schema_any(schema.TTest)
669+
schema_any(schema.TTest2)
670+
schema_any(schema.TTest3)
671+
schema_any(schema.NullableNumbers)
672+
schema_any(schema.TTestExtra)
673+
schema_any(schema.TTestNoExtra)
674+
schema_any(schema.Auto)
675+
schema_any(schema.User)
676+
schema_any(schema.Subject)
677+
schema_any(schema.Language)
678+
schema_any(schema.Experiment)
679+
schema_any(schema.Trial)
680+
schema_any(schema.Ephys)
681+
schema_any(schema.Image)
682+
schema_any(schema.UberTrash)
683+
schema_any(schema.UnterTrash)
684+
schema_any(schema.SimpleSource)
685+
schema_any(schema.SigIntTable)
686+
schema_any(schema.SigTermTable)
687+
schema_any(schema.DjExceptionName)
688+
schema_any(schema.ErrorClass)
689+
schema_any(schema.DecimalPrimaryKey)
690+
schema_any(schema.IndexRich)
691+
schema_any(schema.ThingA)
692+
schema_any(schema.ThingB)
693+
schema_any(schema.ThingC)
694+
schema_any(schema.ThingD)
695+
schema_any(schema.ThingE)
696+
schema_any(schema.Parent)
697+
schema_any(schema.Child)
698+
schema_any(schema.ComplexParent)
699+
schema_any(schema.ComplexChild)
700+
schema_any(schema.SubjectA)
701+
schema_any(schema.SessionA)
702+
schema_any(schema.SessionStatusA)
703+
schema_any(schema.SessionDateA)
704+
schema_any(schema.Stimulus)
705+
schema_any(schema.Longblob)
706+
yield schema_any
707+
try:
708+
schema_any.jobs.delete()
709+
except DataJointError:
710+
pass
711+
schema_any.drop()
712+
713+
606714
@pytest.fixture
607715
def thing_tables(schema_any):
608716
a = schema.ThingA()
@@ -623,7 +731,7 @@ def thing_tables(schema_any):
623731
yield a, b, c, d, e
624732

625733

626-
@pytest.fixture
734+
@pytest.fixture(scope="module")
627735
def schema_simp(connection_test, prefix):
628736
schema = dj.Schema(
629737
prefix + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test
@@ -653,7 +761,7 @@ def schema_simp(connection_test, prefix):
653761
schema.drop()
654762

655763

656-
@pytest.fixture
764+
@pytest.fixture(scope="module")
657765
def schema_adv(connection_test, prefix):
658766
schema = dj.Schema(
659767
prefix + "_advanced",
@@ -694,7 +802,7 @@ def schema_ext(
694802
schema.drop()
695803

696804

697-
@pytest.fixture
805+
@pytest.fixture(scope="module")
698806
def schema_uuid(connection_test, prefix):
699807
schema = dj.Schema(
700808
prefix + "_test1",

tests/test_alter.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515

1616
@pytest.fixture
17-
def schema_alter(connection_test, schema_any):
18-
# Overwrite Experiment and Parent nodes
19-
schema_any(Experiment, context=LOCALS_ALTER)
20-
schema_any(Parent, context=LOCALS_ALTER)
21-
yield schema_any
22-
schema_any.drop()
17+
def schema_alter(connection_test, schema_any_fresh):
18+
# Overwrite Experiment and Parent nodes using fresh schema
19+
schema_any_fresh(Experiment, context=LOCALS_ALTER)
20+
schema_any_fresh(Parent, context=LOCALS_ALTER)
21+
yield schema_any_fresh
22+
schema_any_fresh.drop()
2323

2424

2525
class TestAlter:

tests/test_autopopulate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from . import schema
88

99

10-
def test_populate(trial, subject, experiment, ephys, channel):
10+
def test_populate(clean_autopopulate, trial, subject, experiment, ephys, channel):
1111
# test simple populate
1212
assert subject, "root tables are empty"
1313
assert not experiment, "table already filled?"
@@ -33,7 +33,7 @@ def test_populate(trial, subject, experiment, ephys, channel):
3333
assert channel
3434

3535

36-
def test_populate_with_success_count(subject, experiment, trial):
36+
def test_populate_with_success_count(clean_autopopulate, subject, experiment, trial):
3737
# test simple populate
3838
assert subject, "root tables are empty"
3939
assert not experiment, "table already filled?"
@@ -51,7 +51,7 @@ def test_populate_with_success_count(subject, experiment, trial):
5151
assert len(trial.key_source & trial) == success_count
5252

5353

54-
def test_populate_key_list(subject, experiment, trial):
54+
def test_populate_key_list(clean_autopopulate, subject, experiment, trial):
5555
# test simple populate
5656
assert subject, "root tables are empty"
5757
assert not experiment, "table already filled?"
@@ -63,7 +63,7 @@ def test_populate_key_list(subject, experiment, trial):
6363
assert n == ret["success_count"]
6464

6565

66-
def test_populate_exclude_error_and_ignore_jobs(schema_any, subject, experiment):
66+
def test_populate_exclude_error_and_ignore_jobs(clean_autopopulate, schema_any, subject, experiment):
6767
# test simple populate
6868
assert subject, "root tables are empty"
6969
assert not experiment, "table already filled?"
@@ -79,7 +79,7 @@ def test_populate_exclude_error_and_ignore_jobs(schema_any, subject, experiment)
7979
assert len(experiment.key_source & experiment) == len(experiment.key_source) - 2
8080

8181

82-
def test_allow_direct_insert(subject, experiment):
82+
def test_allow_direct_insert(clean_autopopulate, subject, experiment):
8383
assert subject, "root tables are empty"
8484
key = subject.fetch("KEY", limit=1)[0]
8585
key["experiment_id"] = 1000
@@ -88,14 +88,14 @@ def test_allow_direct_insert(subject, experiment):
8888

8989

9090
@pytest.mark.parametrize("processes", [None, 2])
91-
def test_multi_processing(subject, experiment, processes):
91+
def test_multi_processing(clean_autopopulate, subject, experiment, processes):
9292
assert subject, "root tables are empty"
9393
assert not experiment, "table already filled?"
9494
experiment.populate(processes=None)
9595
assert len(experiment) == len(subject) * experiment.fake_experiments_per_subject
9696

9797

98-
def test_allow_insert(subject, experiment):
98+
def test_allow_insert(clean_autopopulate, subject, experiment):
9999
assert subject, "root tables are empty"
100100
key = subject.fetch("KEY")[0]
101101
key["experiment_id"] = 1001

tests/test_cascading_delete.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@
88

99
@pytest.fixture
1010
def schema_simp_pop(schema_simp):
11+
# Clean up tables first to ensure fresh state with module-scoped schema
12+
# Delete in reverse dependency order
13+
Profile().delete()
14+
Website().delete()
15+
G().delete()
16+
E().delete()
17+
D().delete()
18+
B().delete()
19+
L().delete()
20+
A().delete()
21+
1122
A().insert(A.contents, skip_duplicates=True)
1223
L().insert(L.contents, skip_duplicates=True)
1324
B().populate()

tests/test_declare.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class BadName(dj.Manual):
268268
schema_any(BadName)
269269

270270

271-
def test_bad_fk_rename(schema_any):
271+
def test_bad_fk_rename(schema_any_fresh):
272272
"""issue #381"""
273273

274274
class A(dj.Manual):
@@ -281,9 +281,9 @@ class B(dj.Manual):
281281
b -> A # invalid, the new syntax is (b) -> A
282282
"""
283283

284-
schema_any(A)
284+
schema_any_fresh(A)
285285
with pytest.raises(dj.DataJointError):
286-
schema_any(B)
286+
schema_any_fresh(B)
287287

288288

289289
def test_primary_nullable_foreign_key(schema_any):
@@ -401,7 +401,7 @@ def test_add_hidden_timestamp_default_value():
401401
), "Default value for add_hidden_timestamp is not False"
402402

403403

404-
def test_add_hidden_timestamp_enabled(enable_add_hidden_timestamp, schema_any):
404+
def test_add_hidden_timestamp_enabled(enable_add_hidden_timestamp, schema_any_fresh):
405405
assert config["add_hidden_timestamp"], "add_hidden_timestamp is not enabled"
406406
msg = f"{Experiment().heading._attributes=}"
407407
assert any(
@@ -414,7 +414,7 @@ def test_add_hidden_timestamp_enabled(enable_add_hidden_timestamp, schema_any):
414414
assert not any(a.is_hidden for a in Experiment().heading.attributes.values()), msg
415415

416416

417-
def test_add_hidden_timestamp_disabled(disable_add_hidden_timestamp, schema_any):
417+
def test_add_hidden_timestamp_disabled(disable_add_hidden_timestamp, schema_any_fresh):
418418
assert not config[
419419
"add_hidden_timestamp"
420420
], "expected add_hidden_timestamp to be False"

tests/test_jobs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from . import schema
1010

1111

12-
def test_reserve_job(subject, schema_any):
12+
def test_reserve_job(clean_jobs, subject, schema_any):
1313
assert subject
1414
table_name = "fake_table"
1515

@@ -47,7 +47,7 @@ def test_reserve_job(subject, schema_any):
4747
assert not schema_any.jobs, "failed to clear error jobs"
4848

4949

50-
def test_restrictions(schema_any):
50+
def test_restrictions(clean_jobs, schema_any):
5151
jobs = schema_any.jobs
5252
jobs.delete()
5353
jobs.reserve("a", {"key": "a1"})
@@ -62,7 +62,7 @@ def test_restrictions(schema_any):
6262
jobs.delete()
6363

6464

65-
def test_sigint(schema_any):
65+
def test_sigint(clean_jobs, schema_any):
6666
try:
6767
schema.SigIntTable().populate(reserve_jobs=True)
6868
except KeyboardInterrupt:
@@ -74,7 +74,7 @@ def test_sigint(schema_any):
7474
assert error_message == "KeyboardInterrupt"
7575

7676

77-
def test_sigterm(schema_any):
77+
def test_sigterm(clean_jobs, schema_any):
7878
try:
7979
schema.SigTermTable().populate(reserve_jobs=True)
8080
except SystemExit:
@@ -86,14 +86,14 @@ def test_sigterm(schema_any):
8686
assert error_message == "SystemExit: SIGTERM received"
8787

8888

89-
def test_suppress_dj_errors(schema_any):
89+
def test_suppress_dj_errors(clean_jobs, schema_any):
9090
"""test_suppress_dj_errors: dj errors suppressible w/o native py blobs"""
9191
with dj.config(enable_python_native_blobs=False):
9292
schema.ErrorClass.populate(reserve_jobs=True, suppress_errors=True)
9393
assert len(schema.DjExceptionName()) == len(schema_any.jobs) > 0
9494

9595

96-
def test_long_error_message(subject, schema_any):
96+
def test_long_error_message(clean_jobs, subject, schema_any):
9797
# create long error message
9898
long_error_message = "".join(
9999
random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH + 100)
@@ -129,7 +129,7 @@ def test_long_error_message(subject, schema_any):
129129
schema_any.jobs.delete()
130130

131131

132-
def test_long_error_stack(subject, schema_any):
132+
def test_long_error_stack(clean_jobs, subject, schema_any):
133133
# create long error stack
134134
STACK_SIZE = (
135135
89942 # Does not fit into small blob (should be 64k, but found to be higher)

0 commit comments

Comments
 (0)