Skip to content

Commit e78e1dc

Browse files
authored
[iris] Evict terminal-task resource history past 1h TTL (#4850)
task_resource_history accumulated ~1M rows on marin prod, ~85% for tasks already in terminal states; the existing log-downsample prune only thinned, never evicted. Extends prune_task_resource_history with a TTL pass that drops history for tasks finished more than 1h ago. On the cached marin checkpoint this cut apply_heartbeats_batch baseline p95 from 5.6s to 158ms (~35x). Adds a compound-contention benchmark (benchmark_apply_contention) and fixes clone_db to preserve UNIQUE constraints so register_worker exercises the same UPSERT path as prod.
1 parent 7bf6c3b commit e78e1dc

File tree

3 files changed

+382
-11
lines changed

3 files changed

+382
-11
lines changed

lib/iris/scripts/benchmark_db_queries.py

Lines changed: 276 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,20 @@ def clone_db(source: ControllerDB) -> ControllerDB:
121121
clone_path = clone_dir / ControllerDB.DB_FILENAME
122122
conn = sqlite3.connect(str(clone_path))
123123
conn.execute("ATTACH DATABASE ? AS src", (str(source.db_path),))
124-
# Copy schema + data for each table
125-
for table in _CLONE_TABLES:
126-
conn.execute(f"CREATE TABLE {table} AS SELECT * FROM src.{table}")
127-
# Copy indexes from source schema
124+
125+
# Use the source's real CREATE TABLE DDL — CREATE TABLE AS SELECT drops
126+
# UNIQUE/PRIMARY KEY/CHECK constraints, which breaks UPSERT paths like
127+
# register_worker's INSERT ... ON CONFLICT.
128+
clone_tables = set(_CLONE_TABLES)
129+
table_ddl = conn.execute("SELECT name, sql FROM src.sqlite_master WHERE type='table' AND sql IS NOT NULL").fetchall()
130+
for name, sql in table_ddl:
131+
if name not in clone_tables:
132+
continue
133+
conn.execute(sql)
134+
conn.execute(f"INSERT INTO {name} SELECT * FROM src.{name}")
135+
136+
# Copy indexes from source schema (skip autoindexes — those come from
137+
# UNIQUE/PK constraints already in the CREATE TABLE).
128138
rows = conn.execute("SELECT sql FROM src.sqlite_master WHERE type='index' AND sql IS NOT NULL").fetchall()
129139
for row in rows:
130140
try:
@@ -138,6 +148,7 @@ def clone_db(source: ControllerDB) -> ControllerDB:
138148
conn.execute(row[0])
139149
except sqlite3.OperationalError:
140150
pass
151+
conn.commit()
141152
conn.execute("DETACH DATABASE src")
142153
conn.execute("ANALYZE")
143154
conn.close()
@@ -1216,6 +1227,261 @@ def _burst_100_contended():
12161227
hb_thread.join(timeout=10.0)
12171228

12181229

1230+
def _build_heartbeat_requests(db: ControllerDB) -> list[HeartbeatApplyRequest]:
    """Build a heartbeat batch shaped like a live provider-sync round:
    one HeartbeatApplyRequest per active worker, with one RUNNING
    resource-usage update per task currently assigned to that worker.

    Args:
        db: Controller database to read workers and task assignments from.

    Returns:
        One request per healthy active worker; workers with no active tasks
        still get a request with an empty update list, matching a real
        heartbeat round.
    """
    workers = healthy_active_workers_with_attributes(db)
    active_states = tuple(ACTIVE_TASK_STATES)
    # Derive the IN(...) placeholder list from the actual number of active
    # states. The previous hard-coded "(?, ?, ?)" assumed exactly three
    # members and would silently break the query if ACTIVE_TASK_STATES ever
    # gained or lost one; this also matches how transitions.py builds its
    # TERMINAL_TASK_STATES placeholders.
    state_ph = ",".join("?" for _ in active_states)
    snapshot_proto = job_pb2.WorkerResourceSnapshot()
    # Fixed nominal usage per task — the benchmark cares about write volume,
    # not realistic resource numbers.
    usage = job_pb2.ResourceUsage(cpu_millicores=1000, memory_mb=1024)
    requests: list[HeartbeatApplyRequest] = []
    for w in workers:
        wid = str(w.worker_id)
        rows = db.fetchall(
            f"SELECT task_id, current_attempt_id FROM tasks WHERE current_worker_id = ? AND state IN ({state_ph})",
            (wid, *active_states),
        )
        updates = [
            TaskUpdate(
                task_id=JobName.from_wire(str(r["task_id"])),
                attempt_id=int(r["current_attempt_id"]),
                new_state=job_pb2.TASK_STATE_RUNNING,
                resource_usage=usage,
            )
            for r in rows
        ]
        requests.append(
            HeartbeatApplyRequest(
                worker_id=WorkerId(wid),
                worker_resource_snapshot=snapshot_proto,
                updates=updates,
            )
        )
    return requests
1263+
1264+
1265+
def _build_failure_batch(db: ControllerDB, n: int) -> list[tuple[DispatchBatch, str]]:
    """Pick up to *n* active workers and pair each with a synthetic failure
    reason, shaped like the (batch, reason) tuples fed to
    fail_heartbeats_batch.
    """
    reason = "benchmark: simulated provider-sync failure"
    rows = db.fetchall(
        "SELECT worker_id, address FROM workers WHERE active = 1 LIMIT ?",
        (n,),
    )
    batch: list[tuple[DispatchBatch, str]] = []
    for row in rows:
        address = row["address"]
        dispatch = DispatchBatch(
            worker_id=WorkerId(str(row["worker_id"])),
            worker_address=None if address is None else str(address),
            running_tasks=[],
        )
        batch.append((dispatch, reason))
    return batch
1281+
1282+
1283+
def _print_latency_distribution(name: str, latencies: list[float]) -> None:
    """Print one p50/p95/p99/max summary row for *latencies* (milliseconds)
    and record (name, p50, p95, n) in the module-level _results table.

    Sorts *latencies* in place; prints a "(no samples)" row when empty.
    """
    if not latencies:
        print(f" {name:60s} (no samples)")
        return
    latencies.sort()
    n = len(latencies)
    p50 = latencies[n // 2]
    p95 = latencies[int(n * 0.95)]
    p99 = latencies[int(n * 0.99)]
    slowest = latencies[-1]
    _results.append((name, p50, p95, n))
    line = (
        f" {name:60s} n={n:3d} "
        f"p50={p50:7.1f}ms p95={p95:8.1f}ms p99={p99:8.1f}ms max={slowest:8.1f}ms"
    )
    print(line)
1297+
1298+
1299+
def _run_apply_under_contention(
    *,
    name: str,
    write_db: ControllerDB,
    write_txns: ControllerTransitions,
    heartbeat_requests: list[HeartbeatApplyRequest],
    fail_threads: int = 0,
    fail_n: int = 50,
    fail_chunk: int = 50,
    fail_interval_s: float = 2.0,
    register_threads: int = 0,
    register_burst: int = 100,
    endpoint_threads: int = 0,
    checkpoint_thread: bool = False,
    synchronous_normal: bool = False,
    duration_s: float = 8.0,
) -> None:
    """Run apply_heartbeats_batch repeatedly on a victim thread while
    configurable write storms hammer the same clone DB. Report p50/p95/p99/max
    of the victim's per-call latency.

    Each storm kind gets its own thread count knob (fail_threads,
    register_threads, endpoint_threads); checkpoint_thread adds a single
    forced-WAL-checkpoint loop. All threads share one stop Event and run for
    duration_s seconds; exceptions from any background thread are collected
    and reported after the run rather than aborting it.
    """
    if synchronous_normal:
        # PRAGMA synchronous can't be changed mid-connection once a tx has run,
        # so issue it on a fresh raw connection to the clone file. It persists
        # for that connection only; our ControllerDB connection is unaffected,
        # which is the point — prod can't change synchronous mid-flight either.
        _raw = sqlite3.connect(str(write_db.db_path))
        _raw.execute("PRAGMA synchronous=NORMAL")
        _raw.close()

    # NOTE(review): the literal state IDs (1,2,3,9) are presumably the active
    # task states — confirm they stay in sync with ACTIVE_TASK_STATES.
    endpoint_tasks_rows = write_db.fetchall(
        "SELECT task_id FROM tasks WHERE state IN (1,2,3,9) AND current_attempt_id IS NOT NULL LIMIT 200"
    )
    endpoint_tasks = [JobName.from_wire(str(r["task_id"])) for r in endpoint_tasks_rows]

    stop = threading.Event()
    victim_latencies: list[float] = []
    errors: list[BaseException] = []

    def _victim() -> None:
        # The measured thread: loops apply_heartbeats_batch back-to-back and
        # records per-call wall time in milliseconds.
        try:
            while not stop.is_set():
                t0 = time.perf_counter()
                write_txns.apply_heartbeats_batch(heartbeat_requests)
                victim_latencies.append((time.perf_counter() - t0) * 1000)
        except BaseException as e:
            errors.append(e)

    def _fail_storm() -> None:
        # Periodic bulk worker-failure writes: fail_n workers per round,
        # chunked at fail_chunk, one round per fail_interval_s.
        try:
            while not stop.is_set():
                failures = _build_failure_batch(write_db, fail_n)
                if failures:
                    write_txns.fail_heartbeats_batch(failures, force_remove=True, chunk_size=fail_chunk)
                # Event.wait doubles as an interruptible sleep so the thread
                # exits promptly when stop is set.
                stop.wait(fail_interval_s)
        except BaseException as e:
            errors.append(e)

    def _register_storm() -> None:
        # Bursts of register_worker UPSERTs with fresh unique worker ids each
        # round (uuid suffix avoids colliding with previous bursts).
        try:
            meta = _build_sample_worker_metadata()
            while not stop.is_set():
                base = f"bench-contend-{uuid.uuid4().hex[:8]}"
                for i in range(register_burst):
                    write_txns.register_worker(
                        worker_id=WorkerId(f"{base}-{i}"),
                        address=f"tcp://{base}-{i}:1234",
                        metadata=meta,
                        ts=Timestamp.now(),
                        slice_id="",
                        scale_group="bench",
                    )
                    # Bail mid-burst so shutdown isn't delayed by a long burst.
                    if stop.is_set():
                        break
        except BaseException as e:
            errors.append(e)

    def _endpoint_storm() -> None:
        # Tight loop of add_endpoint writes, round-robining over the sampled
        # active tasks.
        # NOTE(review): if endpoint_tasks is empty, `i % 0` raises
        # ZeroDivisionError on the first iteration — caught by the handler
        # below and surfaced as a background-thread error, not silently lost.
        try:
            i = 0
            while not stop.is_set():
                t = endpoint_tasks[i % len(endpoint_tasks)]
                write_txns.add_endpoint(_make_endpoint(t))
                i += 1
        except BaseException as e:
            errors.append(e)

    def _checkpoint_loop() -> None:
        # Force a full WAL truncate once a second; SQLITE_BUSY surfaces as
        # OperationalError and is expected under write load, so it's ignored.
        try:
            while not stop.is_set():
                try:
                    write_db.execute("PRAGMA wal_checkpoint(TRUNCATE)")
                except sqlite3.OperationalError:
                    pass
                stop.wait(1.0)
        except BaseException as e:
            errors.append(e)

    # The victim always runs; storms are added per the scenario's knobs.
    threads: list[threading.Thread] = [threading.Thread(target=_victim, name="victim")]
    for _ in range(fail_threads):
        threads.append(threading.Thread(target=_fail_storm, name="fail"))
    for _ in range(register_threads):
        threads.append(threading.Thread(target=_register_storm, name="register"))
    for _ in range(endpoint_threads):
        threads.append(threading.Thread(target=_endpoint_storm, name="endpoint"))
    if checkpoint_thread:
        threads.append(threading.Thread(target=_checkpoint_loop, name="checkpoint"))

    for t in threads:
        t.start()
    time.sleep(duration_s)
    stop.set()
    for t in threads:
        # Generous timeout: a storm thread may be blocked on the writer lock
        # mid-transaction when stop is set.
        t.join(timeout=30.0)

    if errors:
        # Report only the first error; later ones are usually cascades.
        print(f" {name}: background thread error: {errors[0]!r}")
    _print_latency_distribution(name, victim_latencies)
1417+
1418+
1419+
def benchmark_apply_contention(db: ControllerDB) -> None:
    """Reproduce the production 'apply results' multi-second tail by running
    apply_heartbeats_batch as the victim under concurrent write storms.

    Runs a fixed ladder of scenarios — baseline, one storm kind at a time,
    a prod-like mix, then heavy variants — against a throwaway clone of *db*.
    """
    heartbeat_requests = _build_heartbeat_requests(db)
    total_tasks = sum(len(r.updates) for r in heartbeat_requests)
    print(f" (victim heartbeat batch: {len(heartbeat_requests)} workers, {total_tasks} tasks)")
    if not heartbeat_requests:
        print(" (skipped, no workers)")
        return

    # Knob set shared by all "heavy" scenarios: two of each storm thread,
    # bigger failure chunks, tighter failure cadence.
    heavy = {
        "fail_threads": 2,
        "fail_chunk": 200,
        "fail_interval_s": 0.5,
        "register_threads": 2,
        "endpoint_threads": 2,
    }
    scenarios: list[dict] = [
        {"name": "apply @ baseline (no contention)"},
        {"name": "apply + 1x fail_heartbeats_batch", "fail_threads": 1},
        {"name": "apply + 1x register_worker burst", "register_threads": 1},
        {"name": "apply + 1x add_endpoint storm", "endpoint_threads": 1},
        {
            "name": "apply + prod-mix (fail + register + endpoint)",
            "fail_threads": 1,
            "register_threads": 1,
            "endpoint_threads": 1,
        },
        {"name": "apply + heavy storm (2f/2r/2e, chunk=200, 0.5s)", **heavy},
        {"name": "apply + heavy + forced WAL checkpoints", **heavy, "checkpoint_thread": True},
        {"name": "apply + heavy + synchronous=NORMAL", **heavy, "synchronous_normal": True},
    ]

    write_db = clone_db(db)
    write_txns = ControllerTransitions(write_db)
    try:
        for scenario in scenarios:
            _run_apply_under_contention(
                write_db=write_db,
                write_txns=write_txns,
                heartbeat_requests=heartbeat_requests,
                **scenario,
            )
    finally:
        # The clone lives in its own temp dir; tear it down even if a
        # scenario raised.
        write_db.close()
        shutil.rmtree(write_db._db_dir, ignore_errors=True)
1483+
1484+
12191485
def print_summary() -> None:
12201486
print("\n" + "=" * 80)
12211487
print(f" {'Query':50s} {'p50':>10s} {'p95':>10s} {'n':>5s}")
@@ -1263,7 +1529,7 @@ def _ensure_db(db_path: Path | None) -> Path:
12631529
@click.option(
12641530
"--only",
12651531
"only_group",
1266-
type=click.Choice(["scheduling", "dashboard", "heartbeat", "endpoints"]),
1532+
type=click.Choice(["scheduling", "dashboard", "heartbeat", "endpoints", "apply_contention"]),
12671533
help="Run only this group",
12681534
)
12691535
@click.option("--no-analyze", is_flag=True, help="Skip ANALYZE to test unoptimized query plans")
@@ -1309,6 +1575,11 @@ def main(db_path: Path | None, only_group: str | None, no_analyze: bool, fresh:
13091575
if only_group is None or only_group == "endpoints":
13101576
print("[endpoints]")
13111577
benchmark_endpoints(db)
1578+
print()
1579+
1580+
if only_group == "apply_contention":
1581+
print("[apply_contention]")
1582+
benchmark_apply_contention(db)
13121583

13131584
print_summary()
13141585
db.close()

lib/iris/src/iris/cluster/controller/transitions.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,17 @@ class ReservationClaim:
113113
"""Maximum task_resource_history rows retained per (task_id, attempt_id).
114114
Logarithmic downsampling triggers at 2x this value."""
115115

116+
TASK_RESOURCE_HISTORY_TERMINAL_TTL = Duration.from_hours(1)
117+
"""After a task reaches a terminal state, its resource history is fully
118+
evicted this long after the finish timestamp. Dashboards surface peak
119+
memory from tasks.peak_memory_mb once a task is done; retaining per-sample
120+
rows forever bloats the DB (~85% of task_resource_history on prod is for
121+
terminal tasks) and amplifies writer contention during heartbeat batches."""
122+
123+
TASK_RESOURCE_HISTORY_DELETE_CHUNK = 5000
124+
"""Maximum ids per DELETE in prune_task_resource_history — bounds how long
125+
the writer lock is held per chunk so other RPCs can interleave."""
126+
116127
DIRECT_PROVIDER_PROMOTION_RATE = 128
117128
"""Token bucket capacity for task promotion (pods per minute).
118129
@@ -2778,12 +2789,40 @@ def prune_worker_resource_history(self) -> int:
27782789
)
27792790

27802791
def prune_task_resource_history(self) -> int:
2781-
"""Logarithmic downsampling: when a (task, attempt) exceeds 2*N rows,
2782-
thin the older half by deleting every other row.
2783-
2784-
Over repeated compaction cycles older data becomes exponentially sparser,
2785-
preserving long-term trends while bounding total row count.
2792+
"""Two-pass prune:
2793+
2794+
1. Evict all history for tasks that have been in a terminal state
2795+
longer than TASK_RESOURCE_HISTORY_TERMINAL_TTL. Dashboards read
2796+
peak memory from tasks.peak_memory_mb after termination; the
2797+
per-sample rows are dead weight and are ~85% of the table on
2798+
prod.
2799+
2. Logarithmic downsampling for anything that remains: when a
2800+
(task, attempt) exceeds 2*N rows, thin the older half by deleting
2801+
every other row so older data grows exponentially sparser.
2802+
2803+
Deletes are chunked so the writer lock releases between chunks.
27862804
"""
2805+
now_ms = Timestamp.now().epoch_ms()
2806+
ttl_cutoff_ms = now_ms - TASK_RESOURCE_HISTORY_TERMINAL_TTL.to_ms()
2807+
terminal_placeholders = ",".join("?" for _ in TERMINAL_TASK_STATES)
2808+
2809+
evicted_terminal = 0
2810+
with self._db.transaction() as cur:
2811+
terminal_ids = [
2812+
str(r["task_id"])
2813+
for r in cur.execute(
2814+
f"SELECT task_id FROM tasks "
2815+
f"WHERE state IN ({terminal_placeholders}) "
2816+
f"AND finished_at_ms IS NOT NULL AND finished_at_ms < ?",
2817+
(*TERMINAL_TASK_STATES, ttl_cutoff_ms),
2818+
).fetchall()
2819+
]
2820+
for chunk_start in range(0, len(terminal_ids), TASK_RESOURCE_HISTORY_DELETE_CHUNK):
2821+
chunk = terminal_ids[chunk_start : chunk_start + TASK_RESOURCE_HISTORY_DELETE_CHUNK]
2822+
ph = ",".join("?" * len(chunk))
2823+
cur.execute(f"DELETE FROM task_resource_history WHERE task_id IN ({ph})", tuple(chunk))
2824+
evicted_terminal += cur.rowcount
2825+
27872826
threshold = TASK_RESOURCE_HISTORY_RETENTION * 2
27882827
with self._db.transaction() as cur:
27892828
overflows = cur.execute(
@@ -2814,9 +2853,11 @@ def prune_task_resource_history(self) -> int:
28142853
ph = ",".join("?" * len(chunk))
28152854
cur.execute(f"DELETE FROM task_resource_history WHERE id IN ({ph})", tuple(chunk))
28162855
total_deleted += cur.rowcount
2856+
if evicted_terminal > 0:
2857+
logger.info("Evicted %d task_resource_history rows (terminal TTL)", evicted_terminal)
28172858
if total_deleted > 0:
28182859
logger.info("Pruned %d task_resource_history rows (log downsampling)", total_deleted)
2819-
return total_deleted
2860+
return evicted_terminal + total_deleted
28202861

28212862
def _batch_delete(
28222863
self,

0 commit comments

Comments
 (0)