diff --git a/src/aces/query.py b/src/aces/query.py index 9f85876..d96f248 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -4,6 +4,7 @@ """ import logging +from datetime import timedelta import polars as pl from bigtree import preorder_iter @@ -11,6 +12,7 @@ from .config import TaskExtractorConfig from .constraints import check_constraints, check_static_variables from .extract_subtree import extract_subtree +from .types import TemporalWindowBounds from .utils import log_tree logger = logging.getLogger(__name__) @@ -78,6 +80,49 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame True >>> "label" in result.columns True + >>> # Test backward window (start: end - 1 day, end: trigger): index_timestamp should be the + >>> # chronological end/start of the window, not the opposite end due to reversed summary timestamps. + >>> cfg_bwd = TaskExtractorConfig( + ... predicates={}, + ... trigger=EventConfig("_ANY_EVENT"), + ... windows={ + ... "input": WindowConfig("end - 1 day", "trigger", True, True, index_timestamp="end"), + ... }, + ... ) + >>> predicates_bwd = pl.DataFrame({ + ... "subject_id": [1], + ... "timestamp": [datetime(2010, 6, 20)], + ... "_ANY_EVENT": [True], + ... }) + >>> with caplog.at_level(logging.INFO): + ... result_bwd = query(cfg_bwd, predicates_bwd) + >>> result_bwd.select("subject_id", "index_timestamp") + shape: (1, 2) + ┌────────────┬─────────────────────┐ + │ subject_id ┆ index_timestamp │ + │ --- ┆ --- │ + │ i64 ┆ datetime[μs] │ + ╞════════════╪═════════════════════╡ + │ 1 ┆ 2010-06-20 00:00:00 │ + └────────────┴─────────────────────┘ + >>> cfg_bwd_start = TaskExtractorConfig( + ... predicates={}, + ... trigger=EventConfig("_ANY_EVENT"), + ... windows={ + ... "input": WindowConfig("end - 1 day", "trigger", True, True, index_timestamp="start"), + ... }, + ... ) + >>> with caplog.at_level(logging.INFO): + ... result_bwd_start = query(cfg_bwd_start, predicates_bwd) + >>> result_bwd_start.select("subject_id", "index_timestamp") + shape: (1, 2) + ┌────────────┬─────────────────────┐ + │ subject_id ┆ index_timestamp │ + │ --- ┆ --- │ + │ i64 ┆ datetime[μs] │ + ╞════════════╪═════════════════════╡ + │ 1 ┆ 2010-06-19 00:00:00 │ + └────────────┴─────────────────────┘ >>> cfg = TaskExtractorConfig( ... predicates={"A": PlainPredicateConfig("A", static=True)}, ... trigger=EventConfig("_ANY_EVENT"), @@ -184,12 +229,23 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame f"Setting index timestamp as '{cfg.windows[cfg.index_timestamp_window].index_timestamp}' " f"of window '{cfg.index_timestamp_window}'..." ) - index_timestamp_col = ( - "end" if cfg.windows[cfg.index_timestamp_window].root_node == "start" else "start" - ) + window_cfg = cfg.windows[cfg.index_timestamp_window] + index_timestamp_col = "end" if window_cfg.root_node == "start" else "start" + # When start_endpoint_expr is a TemporalWindowBounds with negative window_size + # (e.g., start: end - 24h), aggregate_temporal_window places the anchor time in + # timestamp_at_start and the earlier window start in timestamp_at_end (chronologically + # reversed). We must access the opposite field to get the correct chronological timestamp. + if ( + window_cfg.root_node == "end" + and isinstance(window_cfg.start_endpoint_expr, TemporalWindowBounds) + and window_cfg.start_endpoint_expr.window_size < timedelta(0) + ): + timestamp_field = "start" if window_cfg.index_timestamp == "end" else "end" + else: + timestamp_field = window_cfg.index_timestamp result = result.with_columns( pl.col(f"{cfg.index_timestamp_window}.{index_timestamp_col}_summary") - .struct.field(f"timestamp_at_{cfg.windows[cfg.index_timestamp_window].index_timestamp}") + .struct.field(f"timestamp_at_{timestamp_field}") .alias("index_timestamp") ) to_return_cols.insert(1, "index_timestamp")