Skip to content

Commit 43d6f6e

Browse files
authored
Improving Test Coverage (#171)
* Test for end of record as a start event * Test parsing time from None * Tests for query edge cases * Test for missing static predicates * Ignore __init__ * Test label reference * Tests for error messages for loading and initializing functions * Test for missing references in nested derived predicates * Explicitly include __main__.py * Fix --cov-include * Erase Codecov first * Reset by deleting __main__.py * Readd __main__.py * Reset tests * Re-add all tests * Test for invalid args * Ignore logging * Tests for tuple and invalid endpoint_expr * Remove unnecessary imports * Tests for right and left inclusive and removed extra else from coverage * Resolve lazyframe warnings * Test missing right and left inclusive options * Add codecov config * No cover in __main__ * Change __main__ to run * Fix numpy version and run.py tests * No cover for main script * No cover for main execution either * ESGPT import error patch test
1 parent 0b68221 commit 43d6f6e

File tree

12 files changed

+258
-34
lines changed

12 files changed

+258
-34
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ classifiers = [
1919
]
2020
dependencies = [
2121
"polars >= 1.0.0, <= 1.17.1",
22+
"numpy < 1.29.0",
2223
"bigtree == 0.18.*",
2324
"ruamel.yaml == 0.18.*",
2425
"hydra-core == 1.3.*",
@@ -37,7 +38,7 @@ your_package = ["*.pyi", "py.typed"]
3738
[tool.setuptools_scm]
3839

3940
[project.scripts]
40-
aces-cli = "aces.__main__:main"
41+
aces-cli = "aces.run:main"
4142
expand_shards = "aces.expand_shards:main"
4243

4344
[project.optional-dependencies]

src/aces/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
__package_name__ = "es-aces"
88
try:
99
__version__ = version(__package_name__)
10-
except PackageNotFoundError:
10+
except PackageNotFoundError: # pragma: no cover
1111
__version__ = "unknown"

src/aces/aggregate.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,22 @@ def aggregate_event_bound_window(
370370
│ 2 ┆ 1989-12-08 16:22:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
371371
│ 2 ┆ 1989-12-10 03:07:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
372372
└────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘
373+
>>> aggregate_event_bound_window(df, (True, "is_C", True, timedelta(days=3)))
374+
shape: (8, 7)
375+
┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐
376+
│ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │
377+
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
378+
│ i64 ┆ datetime[μs] ┆ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 │
379+
╞════════════╪═════════════════════╪═════════════════════╪═════════════════════╪══════╪══════╪══════╡
380+
│ 1 ┆ 1989-12-01 12:03:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
381+
│ 1 ┆ 1989-12-03 13:14:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
382+
│ 1 ┆ 1989-12-05 15:17:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
383+
│ 2 ┆ 1989-12-02 12:03:00 ┆ 1989-12-05 12:03:00 ┆ 1989-12-06 15:17:00 ┆ 1 ┆ 1 ┆ 1 │
384+
│ 2 ┆ 1989-12-04 13:14:00 ┆ 1989-12-07 13:14:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 2 ┆ 1 │
385+
│ 2 ┆ 1989-12-06 15:17:00 ┆ 1989-12-09 15:17:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 1 ┆ 1 │
386+
│ 2 ┆ 1989-12-08 16:22:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
387+
│ 2 ┆ 1989-12-10 03:07:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │
388+
└────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘
373389
"""
374390
if not isinstance(endpoint_expr, ToEventWindowBounds):
375391
endpoint_expr = ToEventWindowBounds(*endpoint_expr)
@@ -845,6 +861,24 @@ def boolean_expr_bound_sum(
845861
│ 2 ┆ 1989-12-08 16:22:00 ┆ 1989-12-05 16:22:00 ┆ 1989-12-10 03:07:00 ┆ 1 ┆ 3 ┆ 2 │
846862
│ 2 ┆ 1989-12-10 03:07:00 ┆ 1989-12-07 03:07:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 2 ┆ 1 │
847863
└────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘
864+
865+
>>> boolean_expr_bound_sum(df, pl.col("idx").is_in([1, 4, 7]), "invalid_mode", "right",
866+
... offset = timedelta(days=-3))
867+
Traceback (most recent call last):
868+
...
869+
ValueError: Mode 'invalid_mode' invalid!
870+
>>> boolean_expr_bound_sum(df, pl.col("idx").is_in([1, 4, 7]), "row_to_bound", "invalid_closed",
871+
... offset = timedelta(days=-3))
872+
Traceback (most recent call last):
873+
...
874+
ValueError: Closed 'invalid_closed' invalid!
875+
876+
>>> boolean_expr_bound_sum(df, pl.col("idx").is_in([1, 4, 7]), mode="row_to_bound",
877+
... closed="right", offset=timedelta(days=1)).columns
878+
['subject_id', 'timestamp', 'timestamp_at_start', 'timestamp_at_end', 'idx', 'is_A', 'is_B', 'is_C']
879+
>>> boolean_expr_bound_sum(df, pl.col("idx").is_in([1, 4, 7]), mode="row_to_bound",
880+
... closed="left", offset=timedelta(days=-1)).columns
881+
['subject_id', 'timestamp', 'timestamp_at_start', 'timestamp_at_end', 'idx', 'is_A', 'is_B', 'is_C']
848882
"""
849883
if mode not in ("bound_to_row", "row_to_bound"):
850884
raise ValueError(f"Mode '{mode}' invalid!")
@@ -1010,7 +1044,8 @@ def agg_offset_fn(c: str) -> pl.Expr:
10101044
def agg_offset_fn(c: str) -> pl.Expr:
10111045
return pl.col(c) + pl.col(f"{c}_in_offset_period")
10121046

1013-
else:
1047+
# Might not need as mode is already checked above (line 888)
1048+
else: # pragma: no cover
10141049
raise ValueError(f"Mode '{mode}' and offset '{offset}' invalid!")
10151050

10161051
return with_at_boundary_events.join(

src/aces/config.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr:
205205
>>> expr = PlainPredicateConfig("BP//systole", other_cols={"chamber": "atrial"}).ESGPT_eval_expr()
206206
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
207207
[(col("BP")) == (String(systole))].all_horizontal([[(col("chamber")) == (String(atrial))]])
208+
208209
>>> expr = PlainPredicateConfig("BP//systolic", value_min=120).ESGPT_eval_expr()
209210
Traceback (most recent call last):
210211
...
@@ -522,6 +523,7 @@ class WindowConfig:
522523
offset=datetime.timedelta(0))
523524
>>> target_window.root_node
524525
'end'
526+
525527
>>> invalid_window = WindowConfig(
526528
... start="gap.end gap.start",
527529
... end="start -> discharge_or_death",
@@ -1014,7 +1016,6 @@ class TaskExtractorConfig:
10141016
Traceback (most recent call last):
10151017
...
10161018
FileNotFoundError: Cannot load missing configuration file /foo/non_existent_file.yaml!
1017-
10181019
>>> import tempfile
10191020
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f:
10201021
... config_path = Path(f.name)
@@ -1059,7 +1060,6 @@ class TaskExtractorConfig:
10591060
Traceback (most recent call last):
10601061
...
10611062
ValueError: Unrecognized keys in configuration file: 'foo, trigger'
1062-
10631063
>>> predicates = {"foo bar": PlainPredicateConfig("foo")}
10641064
>>> trigger = EventConfig("foo")
10651065
>>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={})
@@ -1083,7 +1083,6 @@ class TaskExtractorConfig:
10831083
Traceback (most recent call last):
10841084
...
10851085
KeyError: "Missing 1 relationships: Derived predicate 'foobar' references undefined predicate 'bar'"
1086-
10871086
>>> predicates = {"foo": PlainPredicateConfig("foo")}
10881087
>>> trigger = EventConfig("foo")
10891088
>>> windows = {"foo bar": WindowConfig("gap.end", "start + 24h", True, True)}
@@ -1119,7 +1118,6 @@ class TaskExtractorConfig:
11191118
...
11201119
ValueError: Only the 'start'/'end' of one window can be used as the index timestamp, found
11211120
2 windows with index_timestamp: foo, bar
1122-
11231121
>>> predicates = {"foo": PlainPredicateConfig("foo")}
11241122
>>> trigger = EventConfig("bar")
11251123
>>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={})
@@ -1255,7 +1253,8 @@ def load(
12551253
... "windows": {
12561254
... "start": {
12571255
... "start": None, "end": "trigger + 24h", "start_inclusive": True,
1258-
... "end_inclusive": True, "has": {"abnormal_labs": "(1, None)"},
1256+
... "end_inclusive": True, "label": "abnormal_labs",
1257+
... "has": {"abnormal_labs": "(1, None)"},
12591258
... }
12601259
... },
12611260
... }
@@ -1285,6 +1284,33 @@ def load(
12851284
...
12861285
ValueError: Predicate 'admission' is not defined correctly in the configuration file. Currently
12871286
defined as the string: invalid. Please refer to the documentation for the supported formats.
1287+
>>> predicates_dict = {
1288+
... "predicates": {'adm': {"code": "admission"}},
1289+
... }
1290+
>>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp,
1291+
... tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp):
1292+
... config_path = Path(config_fp.name)
1293+
... pred_path = Path(pred_fp.name)
1294+
... yaml.dump(no_predicates_config, config_fp)
1295+
... yaml.dump(predicates_dict, pred_fp)
1296+
... cfg = TaskExtractorConfig.load(config_path, pred_path) # doctest: +NORMALIZE_WHITESPACE
1297+
Traceback (most recent call last):
1298+
...
1299+
KeyError: "Something referenced predicate 'admission' that wasn't defined in the configuration."
1300+
>>> config_dict = {
1301+
... "predicates": {"A": {"code": "A"}, "B": {"code": "B"}, "A_or_B": {"expr": "or(A, B)"},
1302+
... "A_or_B_and_C": {"expr": "and(A_or_B, C)"}},
1303+
... "trigger": "_ANY_EVENT",
1304+
... "windows": {"start": {"start": None, "end": "trigger + 24h", "start_inclusive": True,
1305+
... "end_inclusive": True, "has": {"A_or_B_and_C": "(1, None)"}}},
1306+
... }
1307+
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
1308+
... config_path = Path(f.name)
1309+
... yaml.dump(config_dict, f)
1310+
... cfg = TaskExtractorConfig.load(config_path) # doctest: +NORMALIZE_WHITESPACE
1311+
Traceback (most recent call last):
1312+
...
1313+
KeyError: "Predicate 'C' referenced in 'A_or_B_and_C' is not defined in the configuration."
12881314
"""
12891315
if isinstance(config_path, str):
12901316
config_path = Path(config_path)
@@ -1351,7 +1377,7 @@ def load(
13511377
all_predicates = {**final_predicates, **final_demographics}
13521378

13531379
logger.info("Parsing windows...")
1354-
if windows is None:
1380+
if windows is None: # pragma: no cover
13551381
windows = {}
13561382
logger.warning(
13571383
"No windows specified in configuration file. Extracting only matching trigger events."
@@ -1432,6 +1458,23 @@ def _initialize_predicates(self) -> None:
14321458
14331459
Raises:
14341460
ValueError: If the predicate name is not valid.
1461+
1462+
Examples:
1463+
>>> import networkx as nx
1464+
>>> TaskExtractorConfig(
1465+
... predicates={
1466+
... "A": DerivedPredicateConfig("and(A, B)"), # A depends on B
1467+
... "B": DerivedPredicateConfig("and(B, C)"), # B depends on C
1468+
... "C": DerivedPredicateConfig("and(A, C)"), # C depends on A (Cyclic dependency)
1469+
... },
1470+
... trigger=EventConfig("A"),
1471+
... windows={},
1472+
... ) # doctest: +NORMALIZE_WHITESPACE
1473+
Traceback (most recent call last):
1474+
...
1475+
ValueError: Predicate graph is not a directed acyclic graph!
1476+
Cycle found: [('A', 'A')]
1477+
Graph: None
14351478
"""
14361479

14371480
dag_relationships = []
@@ -1479,6 +1522,42 @@ def _initialize_windows(self) -> None:
14791522
14801523
Raises:
14811524
ValueError: If the window name is not valid.
1525+
1526+
Examples:
1527+
>>> TaskExtractorConfig( # doctest: +NORMALIZE_WHITESPACE
1528+
... predicates={"A": PlainPredicateConfig("A")},
1529+
... windows={
1530+
... "win1": WindowConfig(None, "trigger", True, False, has={"B": "(1, 0)"}) # B undefined
1531+
... },
1532+
... trigger=EventConfig("_ANY_EVENT"),
1533+
... ) # doctest: +NORMALIZE_WHITESPACE
1534+
Traceback (most recent call last):
1535+
...
1536+
KeyError: "Window 'win1' references undefined predicate 'B'.
1537+
Window predicates: B;
1538+
Defined predicates: A"
1539+
>>> TaskExtractorConfig(
1540+
... predicates={"A": PlainPredicateConfig("A")},
1541+
... windows={
1542+
... "win1": WindowConfig(None, "event_not_trigger", True, False)
1543+
... },
1544+
... trigger=EventConfig("_ANY_EVENT"),
1545+
... ) # doctest: +NORMALIZE_WHITESPACE
1546+
Traceback (most recent call last):
1547+
...
1548+
KeyError: "Window 'win1' references undefined trigger event
1549+
'event_not_trigger' -- must be trigger!"
1550+
>>> TaskExtractorConfig(
1551+
... predicates={"A": PlainPredicateConfig("A")},
1552+
... windows={
1553+
... "win1": WindowConfig("win2.end", "start -> A", True, False)
1554+
... },
1555+
... trigger=EventConfig("_ANY_EVENT"),
1556+
... ) # doctest: +NORMALIZE_WHITESPACE
1557+
Traceback (most recent call last):
1558+
...
1559+
KeyError: "Window 'win1' references undefined window 'win2' for event 'end'.
1560+
Allowed windows: win1"
14821561
"""
14831562

14841563
for name in self.windows:
@@ -1559,8 +1638,8 @@ def _initialize_windows(self) -> None:
15591638
for predicate in window.referenced_predicates - {ANY_EVENT_COLUMN}:
15601639
if predicate not in self.predicates:
15611640
raise KeyError(
1562-
f"Window '{name}' references undefined predicate '{predicate}'.\n"
1563-
f"Window predicates: {', '.join(window.referenced_predicates)}\n"
1641+
f"Window '{name}' references undefined predicate '{predicate}'. "
1642+
f"Window predicates: {', '.join(window.referenced_predicates)}; "
15641643
f"Defined predicates: {', '.join(self.predicates.keys())}"
15651644
)
15661645

@@ -1580,15 +1659,17 @@ def _initialize_windows(self) -> None:
15801659
f"Window '{name}' references undefined window '{referenced_window}' "
15811660
f"for event '{referenced_event}'. Allowed windows: {', '.join(self.windows.keys())}"
15821661
)
1583-
if referenced_event not in {"start", "end"}:
1662+
# Might not be needed as valid window event references are already checked (line 660)
1663+
if referenced_event not in {"start", "end"}: # pragma: no cover
15841664
raise KeyError(
15851665
f"Window '{name}' references undefined event '{referenced_event}' "
15861666
f"for window '{referenced_window}'. Allowed events: 'start', 'end'"
15871667
)
15881668

15891669
parent_node = f"{referenced_window}.{referenced_event}"
15901670
window_nodes[f"{name}.{window.root_node}"].parent = window_nodes[parent_node]
1591-
else:
1671+
# Might not be needed as valid window event references are already checked (line 660)
1672+
else: # pragma: no cover
15921673
raise ValueError(
15931674
f"Window '{name}' references invalid event '{window.referenced_event}' "
15941675
"must be of length 1 or 2."

src/aces/constraints.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ def check_static_variables(patient_demographics: list[str], predicates_df: pl.Da
166166
│ 1 ┆ 1989-12-02 12:03:00 ┆ 1 ┆ 0 ┆ 1 │
167167
│ 1 ┆ 1989-12-06 11:00:00 ┆ 0 ┆ 0 ┆ 0 │
168168
└────────────┴─────────────────────┴──────┴──────┴──────┘
169+
>>> check_static_variables(['female'], predicates_df)
170+
Traceback (most recent call last):
171+
...
172+
ValueError: Static predicate 'female' not found in the predicates dataframe.
169173
"""
170174
for demographic in patient_demographics:
171175
if demographic not in predicates_df.columns:

src/aces/expand_shards.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def expand_shards(*shards: str) -> str:
4848
... result = expand_shards(tmpdirname)
4949
... sorted(result.split(","))
5050
['1', '3', 'evens/0/file_0', 'evens/0/file_2']
51+
52+
>>> expand_shards("train.invalid")
53+
Traceback (most recent call last):
54+
...
55+
ValueError: Invalid shard format: train.invalid
5156
"""
5257

5358
result = []
@@ -71,9 +76,9 @@ def expand_shards(*shards: str) -> str:
7176
return ",".join(result)
7277

7378

74-
def main() -> None:
79+
def main() -> None: # pragma: no cover
7580
print(expand_shards(*sys.argv[1:]))
7681

7782

78-
if __name__ == "__main__":
83+
if __name__ == "__main__": # pragma: no cover
7984
main()

0 commit comments

Comments
 (0)