Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions src/aces/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
"""

import logging
from datetime import timedelta

import polars as pl
from bigtree import preorder_iter

from .config import TaskExtractorConfig
from .constraints import check_constraints, check_static_variables
from .extract_subtree import extract_subtree
from .types import TemporalWindowBounds
from .utils import log_tree

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -78,6 +80,49 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
True
>>> "label" in result.columns
True
>>> # Backward window (start: end - 1 day, end: trigger). The summary timestamps for such a window
>>> # are stored in reverse chronological order, so check that index_timestamp still resolves to
>>> # the chronological end (the trigger time) or start (one day earlier) of the window.
>>> cfg_bwd = TaskExtractorConfig(
... predicates={},
... trigger=EventConfig("_ANY_EVENT"),
... windows={
... "input": WindowConfig("end - 1 day", "trigger", True, True, index_timestamp="end"),
... },
... )
>>> predicates_bwd = pl.DataFrame({
... "subject_id": [1],
... "timestamp": [datetime(2010, 6, 20)],
... "_ANY_EVENT": [True],
... })
>>> with caplog.at_level(logging.INFO):
... result_bwd = query(cfg_bwd, predicates_bwd)
>>> result_bwd.select("subject_id", "index_timestamp")
shape: (1, 2)
┌────────────┬─────────────────────┐
│ subject_id ┆ index_timestamp │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪═════════════════════╡
│ 1 ┆ 2010-06-20 00:00:00 │
└────────────┴─────────────────────┘
>>> cfg_bwd_start = TaskExtractorConfig(
... predicates={},
... trigger=EventConfig("_ANY_EVENT"),
... windows={
... "input": WindowConfig("end - 1 day", "trigger", True, True, index_timestamp="start"),
... },
... )
>>> with caplog.at_level(logging.INFO):
... result_bwd_start = query(cfg_bwd_start, predicates_bwd)
>>> result_bwd_start.select("subject_id", "index_timestamp")
shape: (1, 2)
┌────────────┬─────────────────────┐
│ subject_id ┆ index_timestamp │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪═════════════════════╡
│ 1 ┆ 2010-06-19 00:00:00 │
└────────────┴─────────────────────┘
>>> cfg = TaskExtractorConfig(
... predicates={"A": PlainPredicateConfig("A", static=True)},
... trigger=EventConfig("_ANY_EVENT"),
Expand Down Expand Up @@ -184,12 +229,23 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
f"Setting index timestamp as '{cfg.windows[cfg.index_timestamp_window].index_timestamp}' "
f"of window '{cfg.index_timestamp_window}'..."
)
index_timestamp_col = (
"end" if cfg.windows[cfg.index_timestamp_window].root_node == "start" else "start"
)
window_cfg = cfg.windows[cfg.index_timestamp_window]
index_timestamp_col = "end" if window_cfg.root_node == "start" else "start"
# When start_endpoint_expr is a TemporalWindowBounds with negative window_size
# (e.g., start: end - 24h), aggregate_temporal_window places the anchor time in
# timestamp_at_start and the earlier window start in timestamp_at_end (chronologically
# reversed). We must access the opposite field to get the correct chronological timestamp.
if (
window_cfg.root_node == "end"
and isinstance(window_cfg.start_endpoint_expr, TemporalWindowBounds)
and window_cfg.start_endpoint_expr.window_size < timedelta(0)
):
timestamp_field = "start" if window_cfg.index_timestamp == "end" else "end"
else:
timestamp_field = window_cfg.index_timestamp
result = result.with_columns(
pl.col(f"{cfg.index_timestamp_window}.{index_timestamp_col}_summary")
.struct.field(f"timestamp_at_{cfg.windows[cfg.index_timestamp_window].index_timestamp}")
.struct.field(f"timestamp_at_{timestamp_field}")
.alias("index_timestamp")
)
to_return_cols.insert(1, "index_timestamp")
Expand Down