Skip to content

Commit 7392a6f

Browse files
Merge pull request #73 from UBC-MDS/yonas.test-fix
Enhance test documentaion for function compare contracts
2 parents 42b9794 + ec244be commit 7392a6f

File tree

1 file changed

+125
-24
lines changed

1 file changed

+125
-24
lines changed

tests/unit/test_compare_contracts.py

Lines changed: 125 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
"""
2-
Basic unit tests for compare_contracts.
2+
Unit tests for compare_contracts.
3+
4+
The compare_contracts function should:
5+
- Identify schema drift (added/removed columns, dtype changes).
6+
- Identify constraint drift (range, category, missingness changes).
7+
- Validate input contracts and raise on invalid rules.
8+
- Report drift in a consistent shape and ordering.
39
"""
410

511
import pytest
@@ -9,30 +15,44 @@
915

1016

1117
def test_added_and_removed_columns():
12-
"""Detect added and removed columns between contracts."""
18+
"""Detects added/removed columns, which would break downstream schema.
19+
20+
We swap the single column name between contracts to ensure both the
21+
addition and removal are reported in the correct direction.
22+
"""
1323
contract_a = Contract(columns={"age": ColumnRule(dtype="int")})
1424
contract_b = Contract(columns={"height": ColumnRule(dtype="int")})
1525

1626
report = compare_contracts(contract_a, contract_b)
1727

28+
# New column appears only in contract_b; old column only in contract_a.
1829
assert report.added_columns == {"height"}
1930
assert report.removed_columns == {"age"}
2031
assert report.has_drift is True
2132

2233

2334
def test_dtype_change():
24-
"""Detect dtype changes for shared columns."""
35+
"""Detects dtype changes for shared columns, a common schema break.
36+
37+
We keep the column name but change its dtype to verify the change is
38+
reported as an (old, new) pair.
39+
"""
2540
contract_a = Contract(columns={"age": ColumnRule(dtype="int")})
2641
contract_b = Contract(columns={"age": ColumnRule(dtype="float")})
2742

2843
report = compare_contracts(contract_a, contract_b)
2944

45+
# Dtype changes are reported as (old, new).
3046
assert report.dtype_changes == {"age": ("int", "float")}
3147
assert report.has_drift is True
3248

3349

3450
def test_range_change():
35-
"""Detect min/max range changes for shared columns."""
51+
"""Detects numeric range drift when dtype stays the same.
52+
53+
We change the min bound for a float column and verify the column is
54+
flagged in the range_changes set.
55+
"""
3656
contract_a = Contract(
3757
columns={"score": ColumnRule(dtype="float", min_value=0.0, max_value=1.0)}
3858
)
@@ -42,12 +62,17 @@ def test_range_change():
4262

4363
report = compare_contracts(contract_a, contract_b)
4464

65+
# Same dtype but different bounds should trigger range drift.
4566
assert report.range_changes == {"score"}
4667
assert report.has_drift is True
4768

4869

4970
def test_category_and_missingness_changes():
50-
"""Detect category and missingness changes for shared columns."""
71+
"""Detects category and missingness drift for a shared categorical column.
72+
73+
We expand allowed values and relax max_missing_frac to confirm both
74+
change types are captured in the report.
75+
"""
5176
contract_a = Contract(
5277
columns={
5378
"status": ColumnRule(
@@ -67,13 +92,18 @@ def test_category_and_missingness_changes():
6792

6893
report = compare_contracts(contract_a, contract_b)
6994

95+
# Allowed values expansion and missingness change are recorded independently.
7096
assert report.category_changes == {"status"}
7197
assert report.missingness_changes == {"status": (0.05, 0.10)}
7298
assert report.has_drift is True
7399

74100

75101
def test_missingness_only_change():
76-
"""Detect missingness drift without other changes."""
102+
"""Detects missingness drift without implying other drift types.
103+
104+
We change only max_missing_frac and confirm dtype/range/category drift
105+
remain empty.
106+
"""
77107
contract_a = Contract(
78108
columns={"age": ColumnRule(dtype="int", max_missing_frac=0.05)}
79109
)
@@ -83,14 +113,19 @@ def test_missingness_only_change():
83113

84114
report = compare_contracts(contract_a, contract_b)
85115

116+
# Only missingness should change; all other categories remain empty.
86117
assert report.missingness_changes == {"age": (0.05, 0.10)}
87118
assert report.dtype_changes == {}
88119
assert report.range_changes == set()
89120
assert report.category_changes == set()
90121

91122

92123
def test_category_only_change():
93-
"""Detect category drift without other changes."""
124+
"""Detects category drift without implying other drift types.
125+
126+
We expand allowed_values and confirm dtype/range/missingness drift
127+
stay empty.
128+
"""
94129
contract_a = Contract(
95130
columns={"status": ColumnRule(dtype="category", allowed_values={"new", "old"})}
96131
)
@@ -104,14 +139,19 @@ def test_category_only_change():
104139

105140
report = compare_contracts(contract_a, contract_b)
106141

142+
# Only categories should change; other buckets remain empty.
107143
assert report.category_changes == {"status"}
108144
assert report.dtype_changes == {}
109145
assert report.range_changes == set()
110146
assert report.missingness_changes == {}
111147

112148

113149
def test_range_change_none_to_value():
114-
"""None to value for range triggers drift."""
150+
"""Detects range drift when optional bounds become defined.
151+
152+
We move min/max from None to numeric values to verify the column is
153+
flagged as a range change.
154+
"""
115155
contract_a = Contract(
116156
columns={"score": ColumnRule(dtype="float", min_value=None, max_value=None)}
117157
)
@@ -121,11 +161,16 @@ def test_range_change_none_to_value():
121161

122162
report = compare_contracts(contract_a, contract_b)
123163

164+
# Optional bounds become defined, which is a drift.
124165
assert report.range_changes == {"score"}
125166

126167

127168
def test_category_change_none_to_set():
128-
"""None to set for categories triggers drift."""
169+
"""Detects category drift when allowed_values becomes defined.
170+
171+
We move allowed_values from None to a set to confirm category drift
172+
is reported.
173+
"""
129174
contract_a = Contract(
130175
columns={"status": ColumnRule(dtype="category", allowed_values=None)}
131176
)
@@ -135,11 +180,16 @@ def test_category_change_none_to_set():
135180

136181
report = compare_contracts(contract_a, contract_b)
137182

183+
# Optional category set becomes defined, which is a drift.
138184
assert report.category_changes == {"status"}
139185

140186

141187
def test_none_to_none_no_drift():
142-
"""None to None for optional fields yields no drift."""
188+
"""Confirms no drift when optional fields stay None.
189+
190+
We keep min/max as None in both contracts and verify the report does
191+
not flag a range or category change.
192+
"""
143193
contract_a = Contract(
144194
columns={"score": ColumnRule(dtype="float", min_value=None, max_value=None)}
145195
)
@@ -149,12 +199,16 @@ def test_none_to_none_no_drift():
149199

150200
report = compare_contracts(contract_a, contract_b)
151201

202+
# Optional fields unchanged -> no drift reported.
152203
assert report.range_changes == set()
153204
assert report.category_changes == set()
154205

155206

156207
def test_dtype_change_blocks_range_drift():
157-
"""Range drift is not reported when dtype changes."""
208+
"""Ensures dtype changes suppress range drift for that column.
209+
210+
We change dtype and bounds to confirm only dtype drift is reported.
211+
"""
158212
contract_a = Contract(
159213
columns={"score": ColumnRule(dtype="int", min_value=0.0, max_value=10.0)}
160214
)
@@ -164,12 +218,16 @@ def test_dtype_change_blocks_range_drift():
164218

165219
report = compare_contracts(contract_a, contract_b)
166220

221+
# Dtype change is reported; range changes are ignored when dtype differs.
167222
assert report.dtype_changes == {"score": ("int", "float")}
168223
assert report.range_changes == set()
169224

170225

171226
def test_dtype_change_blocks_category_drift():
172-
"""Category drift is not reported when dtype changes."""
227+
"""Ensures dtype changes suppress category drift for that column.
228+
229+
We change dtype and allowed_values to verify only dtype drift appears.
230+
"""
173231
contract_a = Contract(
174232
columns={"status": ColumnRule(dtype="category", allowed_values={"new", "old"})}
175233
)
@@ -183,20 +241,27 @@ def test_dtype_change_blocks_category_drift():
183241

184242
report = compare_contracts(contract_a, contract_b)
185243

244+
# Dtype change is reported; category changes are ignored when dtype differs.
186245
assert report.dtype_changes == {"status": ("category", "string")}
187246
assert report.category_changes == set()
188247

189248

190249
def test_invalid_contract_type_raises_typeerror():
191-
"""Invalid contract types raise TypeError."""
250+
"""Validates inputs: non-Contract arguments should raise TypeError.
251+
252+
We pass a string in place of contract_a to ensure the guard triggers.
253+
"""
192254
contract_b = Contract(columns={"age": ColumnRule(dtype="int")})
193255

194256
with pytest.raises(TypeError):
195257
compare_contracts("not-a-contract", contract_b)
196258

197259

198260
def test_invalid_missing_frac_raises_valueerror():
199-
"""Invalid missingness fraction raises ValueError."""
261+
"""Validates rule constraints: max_missing_frac must be within [0, 1].
262+
263+
We set max_missing_frac above 1.0 to confirm the validation raises.
264+
"""
200265
contract_a = Contract(
201266
columns={"age": ColumnRule(dtype="int", max_missing_frac=1.5)}
202267
)
@@ -207,7 +272,10 @@ def test_invalid_missing_frac_raises_valueerror():
207272

208273

209274
def test_non_columnrule_raises_typeerror():
210-
"""Non-ColumnRule entries raise TypeError."""
275+
"""Validates inputs: every column must map to a ColumnRule instance.
276+
277+
We inject a plain string for a column rule to ensure TypeError is raised.
278+
"""
211279
contract_a = Contract(columns={"age": "not-a-rule"})
212280
contract_b = Contract(columns={"age": ColumnRule(dtype="int")})
213281

@@ -216,7 +284,10 @@ def test_non_columnrule_raises_typeerror():
216284

217285

218286
def test_non_numeric_missing_frac_raises_valueerror():
219-
"""Non-numeric max_missing_frac raises ValueError."""
287+
"""Validates rule constraints: max_missing_frac must be numeric.
288+
289+
We pass a non-numeric value and confirm the validation raises.
290+
"""
220291
contract_a = Contract(
221292
columns={"age": ColumnRule(dtype="int", max_missing_frac="high")}
222293
)
@@ -227,7 +298,11 @@ def test_non_numeric_missing_frac_raises_valueerror():
227298

228299

229300
def test_invalid_contract_b_raises_valueerror():
230-
"""Invalid contract_b triggers validation on the second contract."""
301+
"""Validates both contracts: contract_b must pass the same checks.
302+
303+
We put an invalid missingness fraction in contract_b to ensure its
304+
validation is not skipped.
305+
"""
231306
contract_a = Contract(columns={"age": ColumnRule(dtype="int")})
232307
contract_b = Contract(
233308
columns={"age": ColumnRule(dtype="int", max_missing_frac=2.0)}
@@ -238,7 +313,10 @@ def test_invalid_contract_b_raises_valueerror():
238313

239314

240315
def test_min_greater_than_max_raises_valueerror():
241-
"""Invalid range bounds raise ValueError."""
316+
"""Validates numeric bounds: min_value cannot exceed max_value.
317+
318+
We set min_value > max_value to confirm the guard raises ValueError.
319+
"""
242320
contract_a = Contract(
243321
columns={"age": ColumnRule(dtype="int", min_value=10.0, max_value=1.0)}
244322
)
@@ -278,13 +356,20 @@ def test_min_greater_than_max_raises_valueerror():
278356
],
279357
)
280358
def test_has_drift_true_for_any_nonempty_change(contract_a, contract_b):
281-
"""Any non-empty drift category flips has_drift to True."""
359+
"""Reports has_drift True for any non-empty drift category.
360+
361+
We parametrize over each drift type and confirm has_drift flips to True
362+
whenever at least one drift bucket is populated.
363+
"""
282364
report = compare_contracts(contract_a, contract_b)
283365
assert report.has_drift is True
284366

285367

286368
def test_has_drift_false_for_no_drift():
287-
"""No drift yields has_drift False."""
369+
"""Reports has_drift False when contracts are identical.
370+
371+
We compare two identical contracts to ensure the summary flag is False.
372+
"""
288373
contract_a = Contract(columns={"age": ColumnRule(dtype="int")})
289374
contract_b = Contract(columns={"age": ColumnRule(dtype="int")})
290375

@@ -294,7 +379,11 @@ def test_has_drift_false_for_no_drift():
294379

295380

296381
def test_multiple_columns_mixed_drift():
297-
"""Aggregate drift across multiple columns."""
382+
"""Aggregates drift across multiple columns and drift types.
383+
384+
We introduce one change per drift category to confirm all are captured
385+
in a single report.
386+
"""
298387
contract_a = Contract(
299388
columns={
300389
"a": ColumnRule(dtype="int"),
@@ -315,6 +404,7 @@ def test_multiple_columns_mixed_drift():
315404

316405
report = compare_contracts(contract_a, contract_b)
317406

407+
# Each drift category should reflect the intended column(s).
318408
assert report.added_columns == {"e"}
319409
assert report.removed_columns == set()
320410
assert report.dtype_changes == {"a": ("int", "float")}
@@ -324,7 +414,11 @@ def test_multiple_columns_mixed_drift():
324414

325415

326416
def test_missingness_reports_old_new_order():
327-
"""Missingness changes are reported as (old, new)."""
417+
"""Confirms missingness changes are ordered as (old, new).
418+
419+
We increase max_missing_frac and verify the tuple preserves old-to-new
420+
ordering for downstream reporting.
421+
"""
328422
contract_a = Contract(
329423
columns={"age": ColumnRule(dtype="int", max_missing_frac=0.05)}
330424
)
@@ -338,7 +432,10 @@ def test_missingness_reports_old_new_order():
338432

339433

340434
def test_dtype_reports_old_new_order():
341-
"""Dtype changes are reported as (old, new)."""
435+
"""Confirms dtype changes are ordered as (old, new).
436+
437+
We change dtype and ensure the report preserves the old-to-new ordering.
438+
"""
342439
contract_a = Contract(columns={"age": ColumnRule(dtype="int")})
343440
contract_b = Contract(columns={"age": ColumnRule(dtype="float")})
344441

@@ -348,7 +445,11 @@ def test_dtype_reports_old_new_order():
348445

349446

350447
def test_no_drift():
351-
"""No differences yields an empty report."""
448+
"""Confirms no drift yields empty buckets and has_drift False.
449+
450+
We keep all fields identical to verify each drift bucket is empty and
451+
the summary flag remains False.
452+
"""
352453
contract_a = Contract(
353454
columns={"age": ColumnRule(dtype="int", min_value=0.0, max_value=120.0)}
354455
)

0 commit comments

Comments
 (0)