4 changes: 4 additions & 0 deletions lib/iris/src/iris/cluster/controller/controller.py
@@ -1383,6 +1383,10 @@ def _run_prune_loop(self, stop_event: threading.Event) -> None:
                self._transitions.prune_task_resource_history()
            except Exception:
                logger.exception("Task resource history cleanup failed")
            try:
                self._transitions.prune_task_stats_history()
            except Exception:
                logger.exception("Task stats history cleanup failed")

            if wal_checkpoint_limiter.should_run():
                try:
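The four added lines mirror the loop's existing pattern: each cleanup runs inside its own try/except, so a failure in one prune step is logged with a traceback and the remaining steps still run that cycle. A minimal sketch of that shape, as a free function with assumed names (the real method is _run_prune_loop above; the 10-minute cadence is taken from the docstring quoted in the transitions.py hunk below):

import logging
import threading

logger = logging.getLogger(__name__)


def run_prune_loop(transitions, stop_event: threading.Event, interval_s: float = 600.0) -> None:
    # Each step gets its own try/except: one failing cleanup is logged
    # and skipped, and the steps after it still execute this cycle.
    steps = (
        ("Task resource history cleanup", transitions.prune_task_resource_history),
        ("Task stats history cleanup", transitions.prune_task_stats_history),
    )
    while not stop_event.wait(interval_s):
        for label, step in steps:
            try:
                step()
            except Exception:
                logger.exception("%s failed", label)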
23 changes: 23 additions & 0 deletions lib/iris/src/iris/cluster/controller/migrations/0038_task_stats.py
@@ -0,0 +1,23 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

import sqlite3


def migrate(conn: sqlite3.Connection) -> None:
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS task_stats_history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id TEXT NOT NULL REFERENCES tasks(task_id) ON DELETE CASCADE,
            items_processed INTEGER NOT NULL DEFAULT 0,
            bytes_processed INTEGER NOT NULL DEFAULT 0,
            timestamp_ms INTEGER NOT NULL
        )
        """
    )
    conn.execute("CREATE INDEX IF NOT EXISTS idx_task_stats_history_task ON task_stats_history(task_id, id DESC)")

    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk); index 1 is the column name.
    existing = {row[1] for row in conn.execute("PRAGMA table_info(tasks)").fetchall()}
    if "status_message" not in existing:
        conn.execute("ALTER TABLE tasks ADD COLUMN status_message TEXT NOT NULL DEFAULT ''")
24 changes: 24 additions & 0 deletions lib/iris/src/iris/cluster/controller/schema.py
@@ -751,6 +751,8 @@ def generate_full_ddl(tables: Sequence[Table]) -> str:
            default=None,
        ),
        Column("current_worker_address", "TEXT", "", python_type=str | None, decoder=_nullable(str), default=None),
        # Migration 0038
        Column("status_message", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""),
    ),
    table_constraints=("UNIQUE(job_id, task_index)",),
    indexes=(
@@ -1045,6 +1047,25 @@
    ),
)

TASK_STATS_HISTORY = Table(
    "task_stats_history",
    "tsh",
    columns=(
        Column("id", "INTEGER", "PRIMARY KEY AUTOINCREMENT"),
        Column(
            "task_id",
            "TEXT",
            "NOT NULL REFERENCES tasks(task_id) ON DELETE CASCADE",
            python_type=JobName,
            decoder=JobName.from_wire,
        ),
        Column("items_processed", "INTEGER", "NOT NULL DEFAULT 0"),
        Column("bytes_processed", "INTEGER", "NOT NULL DEFAULT 0"),
        Column("timestamp_ms", "INTEGER", "NOT NULL", python_type=Timestamp, decoder=decode_timestamp_ms),
    ),
    indexes=("CREATE INDEX IF NOT EXISTS idx_task_stats_history_task ON task_stats_history(task_id, id DESC)",),
)

ENDPOINTS = Table(
"endpoints",
"e",
@@ -1312,6 +1333,7 @@
    WORKER_TASK_HISTORY,
    WORKER_RESOURCE_HISTORY,
    TASK_RESOURCE_HISTORY,
    TASK_STATS_HISTORY,
    ENDPOINTS,
    DISPATCH_QUEUE,
    SCALING_GROUPS,
@@ -1474,6 +1496,7 @@ class TaskDetailRow:
    current_worker_id: WorkerId | None
    current_worker_address: str | None
    container_id: str | None = None
    status_message: str = ""
    attempts: tuple = dataclasses.field(default_factory=tuple)


@@ -1806,6 +1829,7 @@ def _job_columns(*names: str) -> tuple[tuple[Column, ...], tuple[str, ...]]:
"current_worker_id",
"current_worker_address",
"container_id",
"status_message",
extra_fields=(ExtraField("attempts", tuple, default_factory=tuple),),
row_cls=TaskDetailRow,
)
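The (task_id, id DESC) index declared in both the migration and the schema matches the natural read path: the newest samples for one task. A hedged sketch of that query; latest_task_stats is an illustrative helper, not an API added by this change:

import sqlite3


def latest_task_stats(conn: sqlite3.Connection, task_id: str, n: int = 50) -> list[tuple[int, int, int]]:
    # Walks idx_task_stats_history_task directly: equality on task_id, then
    # ids already in descending order, so LIMIT stops after n index entries.
    return conn.execute(
        "SELECT items_processed, bytes_processed, timestamp_ms"
        " FROM task_stats_history WHERE task_id = ?"
        " ORDER BY id DESC LIMIT ?",
        (task_id, n),
    ).fetchall()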
86 changes: 86 additions & 0 deletions lib/iris/src/iris/cluster/controller/transitions.py
@@ -156,6 +156,17 @@ class ReservationClaim:
runs every 10 min so total wall-time is irrelevant — bounding worst-case
writer hold is what matters for concurrent RPCs."""

TASK_STATS_HISTORY_RETENTION = 50
"""Maximum task_stats_history rows retained per task_id.
Logarithmic downsampling triggers at 2x this value."""

TASK_STATS_HISTORY_TERMINAL_TTL = Duration.from_hours(1)
"""After a task reaches a terminal state, its stats history is fully evicted
this long after the finish timestamp."""

TASK_STATS_HISTORY_DELETE_CHUNK = 1000
"""Maximum task_ids per DELETE in prune_task_stats_history."""

DIRECT_PROVIDER_PROMOTION_RATE = 128
"""Token bucket capacity for task promotion (pods per minute).

@@ -2670,6 +2681,69 @@ def prune_task_resource_history(self) -> int:
logger.info("Pruned %d task_resource_history rows (log downsampling)", total_deleted)
return evicted_terminal + total_deleted

    def prune_task_stats_history(self) -> int:
        """Two-pass prune for task_stats_history, mirroring prune_task_resource_history.

        1. Evict all history for tasks that have been in a terminal state
           longer than TASK_STATS_HISTORY_TERMINAL_TTL.
        2. Logarithmic downsampling for anything that remains: when a task_id
           exceeds 2*N rows, delete every other row among those older than
           the newest N, keeping the oldest row as an anchor.
        """
        now_ms = Timestamp.now().epoch_ms()
        ttl_cutoff_ms = now_ms - TASK_STATS_HISTORY_TERMINAL_TTL.to_ms()
        terminal_placeholders = ",".join("?" for _ in TERMINAL_TASK_STATES)

        with self._db.read_snapshot() as snap:
            terminal_ids = [
                str(r["task_id"])
                for r in snap.fetchall(
                    f"SELECT task_id FROM tasks "
                    f"WHERE state IN ({terminal_placeholders}) "
                    f"AND finished_at_ms IS NOT NULL AND finished_at_ms < ?",
                    (*TERMINAL_TASK_STATES, ttl_cutoff_ms),
                )
            ]

        evicted_terminal = 0
        for chunk_start in range(0, len(terminal_ids), TASK_STATS_HISTORY_DELETE_CHUNK):
            chunk = terminal_ids[chunk_start : chunk_start + TASK_STATS_HISTORY_DELETE_CHUNK]
            ph = ",".join("?" * len(chunk))
            with self._db.transaction() as cur:
                cur.execute(f"DELETE FROM task_stats_history WHERE task_id IN ({ph})", tuple(chunk))
                evicted_terminal += cur.rowcount

        threshold = TASK_STATS_HISTORY_RETENTION * 2
        with self._db.transaction() as cur:
            overflows = cur.execute(
                "SELECT task_id, COUNT(*) as cnt FROM task_stats_history GROUP BY task_id HAVING cnt > ?",
                (threshold,),
            ).fetchall()
            ids_to_delete: list[int] = []
            for row in overflows:
                tid = row["task_id"]
                all_ids = [
                    r["id"]
                    for r in cur.execute(
                        "SELECT id FROM task_stats_history WHERE task_id = ? ORDER BY id ASC",
                        (tid,),
                    ).fetchall()
                ]
                older = all_ids[: len(all_ids) - TASK_STATS_HISTORY_RETENTION]
                ids_to_delete.extend(older[1::2])

            total_deleted = 0
            for chunk_start in range(0, len(ids_to_delete), 900):
                chunk = ids_to_delete[chunk_start : chunk_start + 900]
                ph = ",".join("?" * len(chunk))
                cur.execute(f"DELETE FROM task_stats_history WHERE id IN ({ph})", tuple(chunk))
                total_deleted += cur.rowcount

        if evicted_terminal > 0:
            logger.info("Evicted %d task_stats_history rows (terminal TTL)", evicted_terminal)
        if total_deleted > 0:
            logger.info("Pruned %d task_stats_history rows (log downsampling)", total_deleted)
        return evicted_terminal + total_deleted

    def _batch_delete(
        self,
        sql: str,
@@ -2925,6 +2999,18 @@ def record_task_stats(self, task_id: JobName, items_processed: int, bytes_processed
            bytes_processed,
            status,
        )
        now_ms = Timestamp.now().epoch_ms()
        with self._db.transaction() as cur:
            cur.execute(
                "INSERT INTO task_stats_history"
                " (task_id, items_processed, bytes_processed, timestamp_ms)"
                " VALUES (?, ?, ?, ?)",
                (task_id.to_wire(), items_processed, bytes_processed, now_ms),
            )
            cur.execute(
                "UPDATE tasks SET status_message = ? WHERE task_id = ?",
                (status, task_id.to_wire()),
            )

    # --- Endpoint Management ---

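The downsampling pass in prune_task_stats_history is easiest to sanity-check as a pure function over one task's ordered row ids. A sketch restating just the selection logic (the production code does this against SQL result sets), using the default TASK_STATS_HISTORY_RETENTION of 50:

def downsample_ids(all_ids: list[int], retention: int = 50) -> list[int]:
    """Return ids to delete: every other row among those older than the
    newest `retention` rows; the oldest row (index 0) is kept as an anchor."""
    if len(all_ids) <= retention * 2:  # only triggers above 2x retention
        return []
    older = all_ids[: len(all_ids) - retention]
    return older[1::2]


# Worked example: 101 samples for one task, ids 1..101.
ids = list(range(1, 102))
doomed = downsample_ids(ids)
# The 51 older rows are thinned to 26: ids 2, 4, ..., 50 are deleted.
assert doomed == list(range(2, 51, 2))
assert len(ids) - len(doomed) == 76  # newest 50 plus 26 thinned survivors

Each pass halves the density of the region older than the newest 50 rows, so after repeated prune cycles the surviving old samples end up roughly geometrically spaced, which is the logarithmic behavior the constant's docstring describes.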