Skip to content

Commit f9397ee

Browse files
authored
chore: Merge pull request #200 from swerik-project/fix_mp_test
Fix: MP test revamp; add baseline errors and make sure that the test runs correctly.
2 parents 2ff1a92 + c634bf6 commit f9397ee

7 files changed

Lines changed: 2934 additions & 82 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
protocol_id,person_id,protocol_dates,born,dead,speaker_intro_text

test/data/mp/baseline-dead-mps.csv

Lines changed: 277 additions & 0 deletions
Large diffs are not rendered by default.

test/data/mp/baseline-missing-persons.csv

Lines changed: 1072 additions & 0 deletions
Large diffs are not rendered by default.

test/mp.py

Lines changed: 234 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,255 @@
1-
import unittest
2-
import pandas as pd
3-
from lxml import etree
4-
from pyriksdagen.utils import validate_xml_schema, infer_metadata, get_data_location
5-
from pyriksdagen.db import load_patterns, filter_db, load_ministers, load_metadata
1+
"""
2+
Test suite for validating Swedish parliamentary protocol data.
3+
4+
Checks:
5+
- MPs appearing in protocols outside their mandate periods
6+
- MPs appearing in protocols after death
7+
- MPs appearing in protocols before age 15
8+
9+
Uses baseline CSV files to compare error counts via confidence intervals (CI).
10+
Logs results using trainerlog.
11+
"""
12+
13+
from datetime import datetime
614
from pathlib import Path
15+
from pyriksdagen.db import load_metadata
16+
from pyriksdagen.io import parse_tei
17+
from pyriksdagen.utils import get_doc_dates
18+
from trainerlog import get_logger
19+
20+
import calendar
21+
import math
22+
import os
23+
import pandas as pd
724
import progressbar
8-
import warnings
25+
import unittest
926

27+
logger = get_logger(name="mp-test")
28+
29+
def parse_date_start(s):
    """Parse a value into the datetime at which a period starts.

    Args:
        s: A date-like value, typically a string such as ``"1950-03-04"``,
            ``"1950-03"`` or ``"1950"``.

    Returns:
        ``datetime(1800, 1, 1)`` when ``s`` is missing (NA or blank), the
        parsed timestamp when pandas can parse ``s``, otherwise a ``datetime``
        built from the ``-``-separated parts with a missing month or day
        defaulting to 1. ``None`` when the value cannot be interpreted.
    """
    if pd.isna(s) or str(s).strip() == "":
        return datetime(1800, 1, 1)
    try:
        dt = pd.to_datetime(s, errors="coerce")
        if pd.notna(dt):
            return dt
        # pandas gave up (NaT): fall back to building the date by hand from
        # the year[-month[-day]] parts.
        parts = str(s).split("-")
        year = int(parts[0])
        month = int(parts[1]) if len(parts) > 1 else 1
        day = int(parts[2]) if len(parts) > 2 else 1
        return datetime(year, month, day)
    except (ValueError, TypeError, OverflowError):
        # Only parsing problems are treated as "unparseable"; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        return None
45+
46+
47+
def parse_date_end(s):
    """Parse a value into the datetime at which a period ends.

    A partial date is expanded to the *end* of the period it names:
    ``"1950"`` becomes 1950-12-31 and ``"2000-02"`` becomes 2000-02-29.
    Handing such strings to ``pd.to_datetime`` first would resolve them to the
    first day of the period, so they are detected and expanded before pandas
    is consulted.

    Args:
        s: A date-like value, typically a string such as ``"1950-03-04"``,
            ``"1950-03"`` or ``"1950"``.

    Returns:
        ``pd.Timestamp.max`` as a ``datetime`` when ``s`` is missing (NA or
        blank), the end-of-period ``datetime`` for a year or year-month
        string, the parsed timestamp when pandas can parse a full date, or
        ``None`` when the value cannot be interpreted.
    """
    if pd.isna(s) or str(s).strip() == "":
        return pd.Timestamp.max.to_pydatetime()
    try:
        parts = str(s).strip().split("-")
        # "YYYY" or "YYYY-M[M]": a date with no day (and possibly no month).
        is_partial = (
            len(parts) <= 2
            and len(parts[0]) == 4
            and all(part.isdigit() for part in parts)
            and all(len(part) <= 2 for part in parts[1:])
        )
        if not is_partial:
            dt = pd.to_datetime(s, errors="coerce")
            if pd.notna(dt):
                return dt
        # Either a partial date, or a value pandas could not parse: build the
        # date by hand, defaulting to the last month / last day of the month.
        year = int(parts[0])
        month = int(parts[1]) if len(parts) > 1 else 12
        day = int(parts[2]) if len(parts) > 2 else calendar.monthrange(year, month)[1]
        return datetime(year, month, day)
    except (ValueError, TypeError, OverflowError):
        # Only parsing problems are treated as "unparseable"; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        return None
62+
63+
64+
def assert_ci(baseline_file, new_df, confidence=0.95):
    """Check that the number of rows in ``new_df`` is close to a baseline count.

    The baseline count is the number of rows in the CSV at ``baseline_file``.
    The allowed interval is always ``baseline ± 2 * sqrt(baseline)``;
    ``confidence`` is used only to label that interval in the log output and
    does not change its width.

    Args:
        baseline_file: Path to the baseline CSV (one row per known error).
        new_df: DataFrame (or any sized object) holding the current errors.
        confidence: Nominal confidence level shown in the log message.

    Raises:
        AssertionError: If the new count is above or below the interval.
    """
    baseline_count = len(pd.read_csv(baseline_file))
    margin = 2 * math.sqrt(baseline_count)
    ci_low = baseline_count - margin
    ci_high = baseline_count + margin
    new_count = len(new_df)
    bounds = f"[{ci_low:.0f}, {ci_high:.0f}]"

    logger.info(f"Baseline error count: {baseline_count}")
    logger.info(f"Allowed error count within {round(confidence * 100)}% CI: {bounds}")
    logger.info(f"New error count: {new_count}")

    if new_count > ci_high:
        message = f"Error count {new_count} exceeds upper CI bound {bounds}!"
        logger.error(message)
        raise AssertionError(message)
    if new_count < ci_low:
        message = f"Error count {new_count} falls below lower CI bound {bounds}!"
        logger.error(message)
        raise AssertionError(message)

    # Inside the interval. The interval is centred on the baseline count, so
    # compare against that integer directly: recomputing the midpoint from the
    # float bounds can miss it by a rounding error and would then report an
    # unchanged count as "increased" or "decreased".
    if new_count > baseline_count:
        logger.warning(f"Error count {new_count} increased but remains within CI {bounds}.")
    elif new_count < baseline_count:
        logger.warning(f"Error count {new_count} decreased but remains within CI {bounds}.")
    else:
        logger.info(f"Error count {new_count} is at midpoint of CI; no change detected.")
101+
102+
103+
def aggregate_dates(df_subset):
    """Collapse per-date rows into one row per (protocol, person) pair.

    Args:
        df_subset: DataFrame with the columns ``protocol_id``, ``person_id``,
            ``protocol_date`` (datetime-like), ``born``, ``dead`` and
            ``speaker_intro_text``.

    Returns:
        A DataFrame with one row per ``protocol_id``/``person_id`` pair and
        the columns ``protocol_id``, ``person_id``, ``protocol_dates`` (the
        pair's distinct dates as sorted ``YYYY-MM-DD`` strings joined by
        ``;``), ``born``, ``dead`` and ``speaker_intro_text`` (each the
        group's first value).
    """

    def _join_unique_days(dates):
        # Distinct calendar days, sorted, as one ";"-separated string.
        return ";".join(sorted(dates.dt.strftime("%Y-%m-%d").unique()))

    grouped = df_subset.groupby(["protocol_id", "person_id"])
    return grouped.agg(
        protocol_dates=("protocol_date", _join_unique_days),
        born=("born", "first"),
        dead=("dead", "first"),
        speaker_intro_text=("speaker_intro_text", "first"),
    ).reset_index()
10120

11121
class Test(unittest.TestCase):
    """Checks speaker tags in the protocols against the person metadata."""

    def _collect_speaker_records(self, folder):
        """Collect one record per (tagged utterance, protocol date).

        Walks every ``*.xml`` file in the sub-folders of ``folder``. For each
        ``u`` element with a ``who`` attribute other than ``"unknown"``, one
        record is emitted per parseable date of the protocol. Each record
        also carries the text of the most recent ``note type="speaker"``
        element seen in that protocol (empty string if none yet).

        Returns:
            A list of dicts with the keys ``protocol_id``, ``person_id``,
            ``protocol_date`` and ``speaker_intro_text``.
        """
        records = []
        for outfolder in progressbar.progressbar(list(Path(folder).glob("*/"))):
            for protocol_path in outfolder.glob("*.xml"):
                protocol_id = protocol_path.stem
                root, ns = parse_tei(str(protocol_path))

                _, dates = get_doc_dates(root)
                parsed_dates = (pd.to_datetime(d, errors="coerce") for d in dates if d)
                # Drop dates pandas could not parse (NaT); they cannot be
                # compared with mandates and would break the date formatting
                # further down.
                protocol_dates = [d for d in parsed_dates if pd.notna(d)]
                if not protocol_dates:
                    continue

                last_speaker_intro = ""
                for body in root.findall(f".//{ns['tei_ns']}body"):
                    for div in body.findall(f"{ns['tei_ns']}div"):
                        for elem in div:
                            if (
                                elem.tag == f"{ns['tei_ns']}note"
                                and elem.attrib.get("type") == "speaker"
                            ):
                                if elem.text:
                                    last_speaker_intro = elem.text.strip()
                            elif elem.tag == f"{ns['tei_ns']}u":
                                who = elem.attrib.get("who")
                                if not who or who == "unknown":
                                    continue
                                for pdate in protocol_dates:
                                    records.append({
                                        "protocol_id": protocol_id,
                                        "person_id": who,
                                        "protocol_date": pdate,
                                        "speaker_intro_text": last_speaker_intro,
                                    })
        return records

    def test_protocol(self):
        """Compare speaker-tag error counts against the stored baselines.

        Three kinds of errors are collected, written to ``test/results`` and
        compared with the baseline CSVs in ``test/data/mp`` via ``assert_ci``:
        speakers with no mandate covering any date of the protocol, speakers
        tagged after their date of death, and speakers tagged before the age
        of 15. All three comparisons are run before any failure is reported.
        """
        folder = "data"
        *_, mp_db, minister_db, speaker_db = load_metadata()
        mp_db = pd.concat([mp_db, minister_db, speaker_db])
        mp_db["start"] = mp_db["start"].map(parse_date_start)
        mp_db["end"] = mp_db["end"].map(parse_date_end)

        # Birth/death dates per person, for the dead/child checks.
        mp_doa = mp_db[["id", "born", "dead"]].drop_duplicates().reset_index(drop=True)
        mp_doa["born"] = pd.to_datetime(mp_doa["born"], errors="coerce")
        mp_doa["dead"] = pd.to_datetime(mp_doa["dead"], errors="coerce")
        mp_doa = mp_doa.rename(columns={"id": "person_id"})

        # Find the dates, people and intros in protocols
        records = self._collect_speaker_records(folder)
        if not records:
            # Without records the DataFrame has no columns and every step
            # below would fail with an unhelpful KeyError.
            self.fail(f"No utterances with a known speaker were found under '{folder}'.")

        df = pd.DataFrame(records)

        relevant_ids = df["person_id"].unique()
        mp_subset = mp_db[mp_db["id"].isin(relevant_ids)][["id", "start", "end"]]

        # Left merge: a speaker without any metadata row gets missing
        # start/end and therefore never has a valid mandate.
        df_merged = df.merge(mp_subset, how="left", left_on="person_id", right_on="id")
        df_merged["valid_mandate"] = (
            (df_merged["protocol_date"] >= df_merged["start"])
            & (df_merged["protocol_date"] <= df_merged["end"])
        )

        # For each protocol/person, check if ANY protocol_date matches a mandate
        mandate_check = (
            df_merged.groupby(["protocol_id", "person_id"])["valid_mandate"]
            .any()
            .reset_index()
        )
        df_fail = mandate_check[~mandate_check["valid_mandate"]].drop(columns="valid_mandate")

        df_fail_dates = (
            df.groupby(["protocol_id", "person_id"])["protocol_date"]
            .apply(lambda x: ";".join(sorted(x.dt.strftime("%Y-%m-%d").unique())))
            .reset_index(name="protocol_dates")
        )
        df_fail = df_fail.merge(df_fail_dates, on=["protocol_id", "person_id"])

        df_intro = df[["protocol_id", "person_id", "speaker_intro_text"]].drop_duplicates()
        df_fail = df_fail.merge(df_intro, on=["protocol_id", "person_id"], how="left")

        # Child-dead checks
        df = df.merge(mp_doa, on="person_id", how="left")

        df_dead = df[(df["dead"].notna()) & (df["protocol_date"] > df["dead"])].copy()
        df_child = df[
            (df["born"].notna())
            & (df["protocol_date"] < df["born"] + pd.DateOffset(years=15))
        ].copy()

        df_dead = aggregate_dates(df_dead)
        df_child = aggregate_dates(df_child)

        # One row per protocol/person pair in every report.
        df_fail = df_fail.drop_duplicates(subset=["protocol_id", "person_id"])
        df_dead = df_dead.drop_duplicates(subset=["protocol_id", "person_id"])
        df_child = df_child.drop_duplicates(subset=["protocol_id", "person_id"])

        # Write results
        results_dir = "test/results"
        os.makedirs(results_dir, exist_ok=True)

        df_fail.to_csv(f"{results_dir}/missing-persons.csv", index=False)
        df_dead.to_csv(f"{results_dir}/dead-mps.csv", index=False)
        df_child.to_csv(f"{results_dir}/child-mps.csv", index=False)

        # Check changes
        baseline_dir = "test/data/mp"
        checks = [
            ("MP presence in protocol vs MP mandate periods", "baseline-missing-persons.csv", df_fail),
            ("MPs appearing in protocol after death", "baseline-dead-mps.csv", df_dead),
            ("MPs appearing in protocol before age 15", "baseline-child-mps.csv", df_child),
        ]

        # Run every check before failing so one report lists all regressions.
        failures = []
        for position, (label, baseline_name, errors) in enumerate(checks):
            try:
                logger.info(f"=== Checking {len(errors)} errors for {label} ===")
                assert_ci(f"{baseline_dir}/{baseline_name}", errors)
                if position < len(checks) - 1:
                    # Blank log line between consecutive checks.
                    logger.info("")
            except AssertionError as e:
                failures.append(str(e))

        if failures:
            raise AssertionError("\n".join(failures))
99252

100253

101-
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

test/results/child-mps.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
protocol_id,person_id,protocol_dates,born,dead,speaker_intro_text

0 commit comments

Comments
 (0)