1- import unittest
2- import pandas as pd
3- from lxml import etree
4- from pyriksdagen .utils import validate_xml_schema , infer_metadata , get_data_location
5- from pyriksdagen .db import load_patterns , filter_db , load_ministers , load_metadata
1+ """
2+ Test suite for validating Swedish parliamentary protocol data.
3+
4+ Checks:
5+ - MPs appearing in protocols outside their mandate periods
6+ - MPs appearing in protocols after death
7+ - MPs appearing in protocols before age 15
8+
9+ Uses baseline CSV files to compare error counts via confidence intervals (CI).
10+ Logs results using trainerlog.
11+ """
12+
13+ from datetime import datetime
614from pathlib import Path
15+ from pyriksdagen .db import load_metadata
16+ from pyriksdagen .io import parse_tei
17+ from pyriksdagen .utils import get_doc_dates
18+ from trainerlog import get_logger
19+
20+ import calendar
21+ import math
22+ import os
23+ import pandas as pd
724import progressbar
8- import warnings
25+ import unittest
926
27+ logger = get_logger (name = "mp-test" )
28+
29+ def parse_date_start (s ):
30+ """Parse a string into a start datetime. Returns 1800-01-01 if missing."""
31+
32+ if pd .isna (s ) or str (s ).strip () == "" :
33+ return datetime (1800 , 1 , 1 )
34+ try :
35+ dt = pd .to_datetime (s , errors = "coerce" )
36+ if pd .notna (dt ):
37+ return dt
38+ parts = str (s ).split ("-" )
39+ year = int (parts [0 ])
40+ month = int (parts [1 ]) if len (parts ) > 1 else 1
41+ day = int (parts [2 ]) if len (parts ) > 2 else 1
42+ return datetime (year , month , day )
43+ except :
44+ return None
45+
46+
47+ def parse_date_end (s ):
48+ """Parse a string into an end datetime. Returns max datetime if missing."""
49+ if pd .isna (s ) or str (s ).strip () == "" :
50+ return pd .Timestamp .max .to_pydatetime ()
51+ try :
52+ dt = pd .to_datetime (s , errors = "coerce" )
53+ if pd .notna (dt ):
54+ return dt
55+ parts = str (s ).split ("-" )
56+ year = int (parts [0 ])
57+ month = int (parts [1 ]) if len (parts ) > 1 else 12
58+ day = int (parts [2 ]) if len (parts ) > 2 else calendar .monthrange (year , month )[1 ]
59+ return datetime (year , month , day )
60+ except :
61+ return None
62+
63+
64+ def assert_ci (baseline_file , new_df , confidence = 0.95 ):
65+ """
66+ Compare the number of errors in a DataFrame against a baseline CSV using a confidence interval.
67+ Logs info, warnings, or errors and raises AssertionError if outside CI.
68+ """
69+ df = pd .read_csv (baseline_file )
70+
71+ ci_low = len (df ) - 2 * math .sqrt (len (df ))
72+ ci_high = len (df ) + 2 * math .sqrt (len (df ))
73+
74+ new_count = len (new_df )
75+
76+ logger .info (f"Baseline error count: { len (df )} " )
77+ logger .info (f"Allowed error count within { int (confidence * 100 )} % CI: [{ ci_low :.0f} , { ci_high :.0f} ]" )
78+ logger .info (f"New error count: { new_count } " )
79+
80+ if ci_low <= new_count <= ci_high :
81+ status = "inside"
82+ elif new_count > ci_high :
83+ status = "above"
84+ else :
85+ status = "below"
86+
87+ if status == "inside" :
88+ mid_ci = ci_low + (ci_high - ci_low )/ 2
89+ if new_count > mid_ci :
90+ logger .warning (f"Error count { new_count } increased but remains within CI [{ ci_low :.0f} , { ci_high :.0f} ]." )
91+ elif new_count == mid_ci :
92+ logger .info (f"Error count { new_count } is at midpoint of CI; no change detected." )
93+ else :
94+ logger .warning (f"Error count { new_count } decreased but remains within CI [{ ci_low :.0f} , { ci_high :.0f} ]." )
95+ elif status == "above" :
96+ logger .error (f"Error count { new_count } exceeds upper CI bound [{ ci_low :.0f} , { ci_high :.0f} ]!" )
97+ raise AssertionError (f"Error count { new_count } exceeds upper CI bound [{ ci_low :.0f} , { ci_high :.0f} ]!" )
98+ else :
99+ logger .error (f"Error count { new_count } falls below lower CI bound [{ ci_low :.0f} , { ci_high :.0f} ]!" )
100+ raise AssertionError (f"Error count { new_count } falls below lower CI bound [{ ci_low :.0f} , { ci_high :.0f} ]!" )
101+
102+
103+ def aggregate_dates (df_subset ):
104+ """
105+ Aggregate protocol dates, birth, death, and speaker info for each protocol-person combination.
106+ Returns a DataFrame with combined information.
107+ """
108+ return (
109+ df_subset .groupby (["protocol_id" ,"person_id" ])
110+ .agg (
111+ protocol_dates = ("protocol_date" ,
112+ lambda x : ";" .join (sorted (x .dt .strftime ("%Y-%m-%d" ).unique ()))
113+ ),
114+ born = ("born" ,"first" ),
115+ dead = ("dead" ,"first" ),
116+ speaker_intro_text = ("speaker_intro_text" ,"first" )
117+ )
118+ .reset_index ()
119+ )
10120
11121class Test (unittest .TestCase ):
12122
13- # Official example parla-clarin
14123 def test_protocol (self ):
15- parser = etree .XMLParser (remove_blank_text = True )
16-
17- def test_one_protocol (root , mp_ids , mp_db ):
18- found = True
19- years = []
20- date = None
21- for docDate in root .findall (".//{http://www.tei-c.org/ns/1.0}docDate" ):
22- docDateYear = docDate .attrib .get ("when" , "unknown" )
23- date = docDateYear
24- docDateYear = int (docDateYear .split ("-" )[0 ])
25- years .append (docDateYear )
26-
27- for year in years :
28- if year not in mp_ids :
29- year_db = filter_db (mp_db , year = year )
30- ids = list (year_db ["id" ])
31- mp_ids [year ] = ids
32-
33- false_whos = []
34- whos = set ()
35- for body in root .findall (".//{http://www.tei-c.org/ns/1.0}body" ):
36- for div in body .findall ("{http://www.tei-c.org/ns/1.0}div" ):
37- for ix , elem in enumerate (div ):
38- if elem .tag == "{http://www.tei-c.org/ns/1.0}u" :
39- who = elem .attrib .get ("who" , "unknown" )
40- if who != "unknown" :
41- whos .add (who )
42- elem_found = False
43- for year in years :
44- if who in mp_ids [year ]:
45- elem_found = True
46-
47- if not elem_found :
48- found = False
49- false_whos .append (who )
50-
51- # Check for dead or child speakers
52- dead_whos = []
53- child_whos = []
54- mp_doa = mp_db [['id' , 'born' , 'dead' ]].drop_duplicates ().reset_index (drop = True )
55- mp_doa ['born' ] = mp_doa ['born' ].fillna ('0000' )
56- mp_doa ['dead' ] = mp_doa ['dead' ].fillna ('9999' )
57-
58- fronts = root .findall (".//{http://www.tei-c.org/ns/1.0}front" )
59- heads = fronts [0 ].findall (".//{http://www.tei-c.org/ns/1.0}head" )
60- for who in whos :
61- mp = mp_doa .loc [mp_doa ['id' ] == who ]
62-
63- warning_text = f"Speaker { who } not found in db. Protocol { heads [0 ].text } "
64- self .assertGreaterEqual (len (mp ), 1 , warning_text )
65-
66- born = min (mp ['born' ].apply (lambda x : int (x [:4 ])))
67- dead = max (mp ['dead' ].apply (lambda x : int (x [:4 ])))
68- if max (years ) > dead :
69- dead_whos .append (who )
70- if max (years ) < born + 15 :
71- child_whos .append (who )
72-
73- return found , false_whos , dead_whos , child_whos
74-
75- # new
124+
76125 folder = "data"
77126 * _ , mp_db , minister_db , speaker_db = load_metadata ()
127+
78128 mp_db = pd .concat ([mp_db , minister_db , speaker_db ])
129+ mp_db ["start" ] = mp_db ["start" ].map (parse_date_start )
130+ mp_db ["end" ] = mp_db ["end" ].map (parse_date_end )
131+
132+ mp_doa = mp_db [["id" , "born" , "dead" ]].drop_duplicates ().reset_index (drop = True )
133+ mp_doa ["born" ] = pd .to_datetime (mp_doa ["born" ], errors = "coerce" )
134+ mp_doa ["dead" ] = pd .to_datetime (mp_doa ["dead" ], errors = "coerce" )
135+ mp_doa = mp_doa .rename (columns = {"id" : "person_id" })
79136
80- mp_ids = {}
137+ records = []
81138
82- failed_protocols = []
139+ # Find the dates, people and intros in protocols
83140 for outfolder in progressbar .progressbar (list (Path (folder ).glob ("*/" ))):
84141 for protocol_path in outfolder .glob ("*.xml" ):
142+
85143 protocol_id = protocol_path .stem
86- path_str = str (protocol_path .resolve ())
87- root = etree .parse (path_str , parser ).getroot ()
88- found , false_whos , dead_whos , child_whos = test_one_protocol (root , mp_ids , mp_db )
89- if not found :
90- failed_protocols .append (protocol_id + " (" + false_whos [0 ] + ")" )
144+ root , ns = parse_tei (str (protocol_path ))
145+
146+ match_error , dates = get_doc_dates (root )
147+ protocol_dates = [pd .to_datetime (d , errors = "coerce" ) for d in dates if d ]
148+
149+ if not protocol_dates :
150+ continue
151+
152+ last_speaker_intro = ""
153+
154+ for body in root .findall (f".//{ ns ['tei_ns' ]} body" ):
155+ for div in body .findall (f"{ ns ['tei_ns' ]} div" ):
156+ for elem in div :
157+
158+ if (
159+ elem .tag == f"{ ns ['tei_ns' ]} note"
160+ and elem .attrib .get ("type" ) == "speaker"
161+ ):
162+ if elem .text :
163+ last_speaker_intro = elem .text .strip ()
164+
165+ elif elem .tag == f"{ ns ['tei_ns' ]} u" :
166+
167+ who = elem .attrib .get ("who" )
168+
169+ if not who or who == "unknown" :
170+ continue
171+
172+ for pdate in protocol_dates :
173+ records .append ({"protocol_id" : protocol_id , "person_id" : who , "protocol_date" : pdate , "speaker_intro_text" : last_speaker_intro })
174+
175+ df = pd .DataFrame (records )
176+
177+ relevant_ids = df ["person_id" ].unique ()
178+ mp_subset = mp_db [mp_db ["id" ].isin (relevant_ids )][["id" , "start" , "end" ]]
179+
180+ df_merged = df .merge (mp_subset , how = "left" , left_on = "person_id" , right_on = "id" )
181+ df_merged ["valid_mandate" ] = ((df_merged ["protocol_date" ] >= df_merged ["start" ]) & (df_merged ["protocol_date" ] <= df_merged ["end" ]))
182+
183+ # For each protocol/person, check if ANY protocol_date matches a mandate
184+ mandate_check = (df_merged .groupby (["protocol_id" ,"person_id" ])["valid_mandate" ].any ().reset_index ())
185+
186+ df_fail = mandate_check [~ mandate_check ["valid_mandate" ]].drop (columns = "valid_mandate" )
187+
188+ df_fail_dates = (
189+ df .groupby (["protocol_id" , "person_id" ])["protocol_date" ]
190+ .apply (
191+ lambda x : ";" .join (
192+ sorted (x .dt .strftime ("%Y-%m-%d" ).unique ())
193+ )
194+ )
195+ .reset_index (name = "protocol_dates" )
196+ )
197+
198+ df_fail = df_fail .merge (df_fail_dates , on = ["protocol_id" , "person_id" ])
199+
200+ df_intro = df [["protocol_id" , "person_id" , "speaker_intro_text" ]].drop_duplicates ()
201+
202+ df_fail = df_fail .merge (df_intro , on = ["protocol_id" , "person_id" ], how = "left" )
203+
204+ # Child-dead checks
205+ df = df .merge (mp_doa , on = "person_id" , how = "left" )
206+
207+ df_dead = df [(df ["dead" ].notna ()) & (df ["protocol_date" ] > df ["dead" ])].copy ()
208+ df_child = df [(df ["born" ].notna ()) & (df ["protocol_date" ] < df ["born" ] + pd .DateOffset (years = 15 ))].copy ()
209+
210+ df_dead = aggregate_dates (df_dead )
211+ df_child = aggregate_dates (df_child )
212+
213+ df_fail = df_fail .drop_duplicates (subset = ["protocol_id" ,"person_id" ])
214+ df_dead = df_dead .drop_duplicates (subset = ["protocol_id" ,"person_id" ])
215+ df_child = df_child .drop_duplicates (subset = ["protocol_id" ,"person_id" ])
216+
217+ # Write results
218+ results_dir = "test/results"
219+ os .makedirs (results_dir , exist_ok = True )
220+
221+ df_fail .to_csv (f"{ results_dir } /missing-persons.csv" , index = False )
222+ df_dead .to_csv (f"{ results_dir } /dead-mps.csv" , index = False )
223+ df_child .to_csv (f"{ results_dir } /child-mps.csv" , index = False )
224+
225+ # Check changes
226+ failures = []
227+
228+ baseline_dir = "test/data/mp"
91229
92- print ("Protocols with inactive MPs tagged as speakers:" , ", " .join (failed_protocols ))
93- print ("Dead MPs tagged as speakers:" , ", " .join (dead_whos ))
94- print ("Children MPs tagged as speakers:" , ", " .join (child_whos ))
230+ try :
231+ logger .info (f"=== Checking { len (df_fail )} errors for MP presence in protocol vs MP mandate periods ===" )
232+ assert_ci (f"{ baseline_dir } /baseline-missing-persons.csv" , df_fail )
233+ logger .info ("" )
234+ except AssertionError as e :
235+ failures .append (str (e ))
95236
96- self .assertEqual (len (failed_protocols ), 0 )
237+ try :
238+ logger .info (f"=== Checking { len (df_dead )} errors for MPs appearing in protocol after death ===" )
239+ assert_ci (f"{ baseline_dir } /baseline-dead-mps.csv" , df_dead )
240+ logger .info ("" )
241+ except AssertionError as e :
242+ failures .append (str (e ))
97243
244+ try :
245+ logger .info (f"=== Checking { len (df_child )} errors for MPs appearing in protocol before age 15 ===" )
246+ assert_ci (f"{ baseline_dir } /baseline-child-mps.csv" , df_child )
247+ except AssertionError as e :
248+ failures .append (str (e ))
98249
250+ if failures :
251+ raise AssertionError ("\n " .join (failures ))
99252
100253
101- if __name__ == '__main__' :
102- # begin the unittest.main()
103- unittest .main ()
254+ if __name__ == "__main__" :
255+ unittest .main ()
0 commit comments