Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 276 additions & 5 deletions lib/iris/scripts/benchmark_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,20 @@ def clone_db(source: ControllerDB) -> ControllerDB:
clone_path = clone_dir / ControllerDB.DB_FILENAME
conn = sqlite3.connect(str(clone_path))
conn.execute("ATTACH DATABASE ? AS src", (str(source.db_path),))
# Copy schema + data for each table
for table in _CLONE_TABLES:
conn.execute(f"CREATE TABLE {table} AS SELECT * FROM src.{table}")
# Copy indexes from source schema

# Use the source's real CREATE TABLE DDL — CREATE TABLE AS SELECT drops
# UNIQUE/PRIMARY KEY/CHECK constraints, which breaks UPSERT paths like
# register_worker's INSERT ... ON CONFLICT.
clone_tables = set(_CLONE_TABLES)
table_ddl = conn.execute("SELECT name, sql FROM src.sqlite_master WHERE type='table' AND sql IS NOT NULL").fetchall()
for name, sql in table_ddl:
if name not in clone_tables:
continue
conn.execute(sql)
conn.execute(f"INSERT INTO {name} SELECT * FROM src.{name}")

# Copy indexes from source schema (skip autoindexes — those come from
# UNIQUE/PK constraints already in the CREATE TABLE).
rows = conn.execute("SELECT sql FROM src.sqlite_master WHERE type='index' AND sql IS NOT NULL").fetchall()
for row in rows:
try:
Expand All @@ -138,6 +148,7 @@ def clone_db(source: ControllerDB) -> ControllerDB:
conn.execute(row[0])
except sqlite3.OperationalError:
pass
conn.commit()
conn.execute("DETACH DATABASE src")
conn.execute("ANALYZE")
conn.close()
Expand Down Expand Up @@ -1216,6 +1227,261 @@ def _burst_100_contended():
hb_thread.join(timeout=10.0)


def _build_heartbeat_requests(db: ControllerDB) -> list[HeartbeatApplyRequest]:
    """Build a heartbeat batch shaped like a live provider-sync round:
    one HeartbeatApplyRequest per active worker, with one RUNNING
    resource-usage update per task currently assigned to that worker.

    Args:
        db: Controller database to read workers and task assignments from.

    Returns:
        One request per healthy active worker; workers with no active tasks
        still get a request with an empty update list (matching a real
        heartbeat round).
    """
    workers = healthy_active_workers_with_attributes(db)
    active_states = tuple(ACTIVE_TASK_STATES)
    # Derive the IN (...) placeholder list from the actual number of active
    # states — a hard-coded "(?, ?, ?)" breaks silently (sqlite3 binding
    # error) if ACTIVE_TASK_STATES ever gains or loses a member.
    state_ph = ", ".join("?" for _ in active_states)
    snapshot_proto = job_pb2.WorkerResourceSnapshot()
    # Fixed synthetic usage; the benchmark only cares about write volume,
    # not the specific numbers.
    usage = job_pb2.ResourceUsage(cpu_millicores=1000, memory_mb=1024)
    requests: list[HeartbeatApplyRequest] = []
    for w in workers:
        wid = str(w.worker_id)
        rows = db.fetchall(
            "SELECT task_id, current_attempt_id FROM tasks "
            f"WHERE current_worker_id = ? AND state IN ({state_ph})",
            (wid, *active_states),
        )
        updates = [
            TaskUpdate(
                task_id=JobName.from_wire(str(r["task_id"])),
                attempt_id=int(r["current_attempt_id"]),
                new_state=job_pb2.TASK_STATE_RUNNING,
                resource_usage=usage,
            )
            for r in rows
        ]
        requests.append(
            HeartbeatApplyRequest(
                worker_id=WorkerId(wid),
                worker_resource_snapshot=snapshot_proto,
                updates=updates,
            )
        )
    return requests


def _build_failure_batch(db: ControllerDB, n: int) -> list[tuple[DispatchBatch, str]]:
    """Return up to *n* (DispatchBatch, reason) pairs for currently-active
    workers, shaped like a provider-sync failure sweep."""
    reason = "benchmark: simulated provider-sync failure"
    rows = db.fetchall(
        "SELECT worker_id, address FROM workers WHERE active = 1 LIMIT ?",
        (n,),
    )
    batches: list[tuple[DispatchBatch, str]] = []
    for row in rows:
        address = row["address"]
        batch = DispatchBatch(
            worker_id=WorkerId(str(row["worker_id"])),
            worker_address=str(address) if address is not None else None,
            running_tasks=[],
        )
        batches.append((batch, reason))
    return batches


def _print_latency_distribution(name: str, latencies: list[float]) -> None:
    """Print p50/p95/p99/max for *latencies* (ms) and record (name, p50, p95, n)
    in the module-level ``_results`` summary.

    NOTE: sorts the caller's list in place — acceptable here because the
    benchmark never reuses the list after reporting.
    """
    n = len(latencies)
    if n == 0:
        print(f" {name:60s} (no samples)")
        return
    latencies.sort()
    p50 = latencies[n // 2]
    p95 = latencies[int(n * 0.95)]
    p99 = latencies[int(n * 0.99)]
    worst = latencies[-1]
    _results.append((name, p50, p95, n))
    print(
        f" {name:60s} n={n:3d} "
        f"p50={p50:7.1f}ms p95={p95:8.1f}ms p99={p99:8.1f}ms max={worst:8.1f}ms"
    )


def _run_apply_under_contention(
    *,
    name: str,
    write_db: ControllerDB,
    write_txns: ControllerTransitions,
    heartbeat_requests: list[HeartbeatApplyRequest],
    fail_threads: int = 0,
    fail_n: int = 50,
    fail_chunk: int = 50,
    fail_interval_s: float = 2.0,
    register_threads: int = 0,
    register_burst: int = 100,
    endpoint_threads: int = 0,
    checkpoint_thread: bool = False,
    synchronous_normal: bool = False,
    duration_s: float = 8.0,
) -> None:
    """Run apply_heartbeats_batch repeatedly on a victim thread while
    configurable write storms hammer the same clone DB. Report p50/p95/p99/max
    of the victim's per-call latency.

    Args:
        name: Scenario label used in the printed report.
        write_db / write_txns: Clone database and its transition layer; all
            threads share these.
        heartbeat_requests: The victim workload applied in a loop.
        fail_threads / fail_n / fail_chunk / fail_interval_s: Number of
            failure-storm threads, workers failed per sweep, delete chunk
            size, and sleep between sweeps.
        register_threads / register_burst: Register-storm thread count and
            workers registered per burst.
        endpoint_threads: add_endpoint storm thread count.
        checkpoint_thread: Also run forced WAL TRUNCATE checkpoints.
        synchronous_normal: Issue PRAGMA synchronous=NORMAL on a throwaway
            connection (see inline comment for why it is deliberately inert).
        duration_s: How long to let the storm run before stopping.
    """
    if synchronous_normal:
        # PRAGMA synchronous can't be changed mid-connection once a tx has run,
        # so issue it on a fresh raw connection to the clone file. It persists
        # for that connection only; our ControllerDB connection is unaffected,
        # which is the point — prod can't change synchronous mid-flight either.
        _raw = sqlite3.connect(str(write_db.db_path))
        _raw.execute("PRAGMA synchronous=NORMAL")
        _raw.close()

    endpoint_tasks_rows = write_db.fetchall(
        "SELECT task_id FROM tasks WHERE state IN (1,2,3,9) AND current_attempt_id IS NOT NULL LIMIT 200"
    )
    endpoint_tasks = [JobName.from_wire(str(r["task_id"])) for r in endpoint_tasks_rows]

    stop = threading.Event()
    victim_latencies: list[float] = []
    errors: list[BaseException] = []

    def _victim():
        # Measure per-call latency (ms) of the victim operation until stopped.
        try:
            while not stop.is_set():
                t0 = time.perf_counter()
                write_txns.apply_heartbeats_batch(heartbeat_requests)
                victim_latencies.append((time.perf_counter() - t0) * 1000)
        except BaseException as e:
            errors.append(e)

    def _fail_storm():
        # Periodic bulk worker failure: generates large chunked write txns.
        try:
            while not stop.is_set():
                failures = _build_failure_batch(write_db, fail_n)
                if failures:
                    write_txns.fail_heartbeats_batch(failures, force_remove=True, chunk_size=fail_chunk)
                stop.wait(fail_interval_s)
        except BaseException as e:
            errors.append(e)

    def _register_storm():
        # Bursts of register_worker UPSERTs under fresh unique worker ids.
        try:
            meta = _build_sample_worker_metadata()
            while not stop.is_set():
                base = f"bench-contend-{uuid.uuid4().hex[:8]}"
                for i in range(register_burst):
                    write_txns.register_worker(
                        worker_id=WorkerId(f"{base}-{i}"),
                        address=f"tcp://{base}-{i}:1234",
                        metadata=meta,
                        ts=Timestamp.now(),
                        slice_id="",
                        scale_group="bench",
                    )
                if stop.is_set():
                    break
        except BaseException as e:
            errors.append(e)

    def _endpoint_storm():
        try:
            # Guard: with no eligible tasks the modulo below would raise
            # ZeroDivisionError on the first iteration; idle out instead of
            # polluting `errors` and tainting the scenario report.
            if not endpoint_tasks:
                return
            i = 0
            while not stop.is_set():
                t = endpoint_tasks[i % len(endpoint_tasks)]
                write_txns.add_endpoint(_make_endpoint(t))
                i += 1
        except BaseException as e:
            errors.append(e)

    def _checkpoint_loop():
        try:
            while not stop.is_set():
                try:
                    write_db.execute("PRAGMA wal_checkpoint(TRUNCATE)")
                except sqlite3.OperationalError:
                    # Checkpoints can transiently fail while writers hold the
                    # lock; best-effort is fine for a benchmark.
                    pass
                stop.wait(1.0)
        except BaseException as e:
            errors.append(e)

    threads: list[threading.Thread] = [threading.Thread(target=_victim, name="victim")]
    for _ in range(fail_threads):
        threads.append(threading.Thread(target=_fail_storm, name="fail"))
    for _ in range(register_threads):
        threads.append(threading.Thread(target=_register_storm, name="register"))
    for _ in range(endpoint_threads):
        threads.append(threading.Thread(target=_endpoint_storm, name="endpoint"))
    if checkpoint_thread:
        threads.append(threading.Thread(target=_checkpoint_loop, name="checkpoint"))

    for t in threads:
        t.start()
    time.sleep(duration_s)
    stop.set()
    for t in threads:
        t.join(timeout=30.0)

    if errors:
        print(f" {name}: background thread error: {errors[0]!r}")
    _print_latency_distribution(name, victim_latencies)


def benchmark_apply_contention(db: ControllerDB) -> None:
    """Reproduce the production 'apply results' multi-second tail by running
    apply_heartbeats_batch as the victim under concurrent write storms.
    """
    heartbeat_requests = _build_heartbeat_requests(db)
    task_count = sum(len(req.updates) for req in heartbeat_requests)
    print(f" (victim heartbeat batch: {len(heartbeat_requests)} workers, {task_count} tasks)")

    if not heartbeat_requests:
        print(" (skipped, no workers)")
        return

    # Shared knob set for the "heavy" scenarios below.
    heavy = {
        "fail_threads": 2,
        "fail_chunk": 200,
        "fail_interval_s": 0.5,
        "register_threads": 2,
        "endpoint_threads": 2,
    }
    scenarios: list[dict] = [
        {"name": "apply @ baseline (no contention)"},
        {"name": "apply + 1x fail_heartbeats_batch", "fail_threads": 1},
        {"name": "apply + 1x register_worker burst", "register_threads": 1},
        {"name": "apply + 1x add_endpoint storm", "endpoint_threads": 1},
        {
            "name": "apply + prod-mix (fail + register + endpoint)",
            "fail_threads": 1,
            "register_threads": 1,
            "endpoint_threads": 1,
        },
        {"name": "apply + heavy storm (2f/2r/2e, chunk=200, 0.5s)", **heavy},
        {"name": "apply + heavy + forced WAL checkpoints", **heavy, "checkpoint_thread": True},
        {"name": "apply + heavy + synchronous=NORMAL", **heavy, "synchronous_normal": True},
    ]

    write_db = clone_db(db)
    write_txns = ControllerTransitions(write_db)
    try:
        for cfg in scenarios:
            _run_apply_under_contention(
                write_db=write_db,
                write_txns=write_txns,
                heartbeat_requests=heartbeat_requests,
                **cfg,
            )
    finally:
        # Clone lives in a temp dir; remove it even if a scenario blew up.
        write_db.close()
        shutil.rmtree(write_db._db_dir, ignore_errors=True)


def print_summary() -> None:
print("\n" + "=" * 80)
print(f" {'Query':50s} {'p50':>10s} {'p95':>10s} {'n':>5s}")
Expand Down Expand Up @@ -1263,7 +1529,7 @@ def _ensure_db(db_path: Path | None) -> Path:
@click.option(
"--only",
"only_group",
type=click.Choice(["scheduling", "dashboard", "heartbeat", "endpoints"]),
type=click.Choice(["scheduling", "dashboard", "heartbeat", "endpoints", "apply_contention"]),
help="Run only this group",
)
@click.option("--no-analyze", is_flag=True, help="Skip ANALYZE to test unoptimized query plans")
Expand Down Expand Up @@ -1309,6 +1575,11 @@ def main(db_path: Path | None, only_group: str | None, no_analyze: bool, fresh:
if only_group is None or only_group == "endpoints":
print("[endpoints]")
benchmark_endpoints(db)
print()

if only_group == "apply_contention":
print("[apply_contention]")
benchmark_apply_contention(db)

print_summary()
db.close()
Expand Down
53 changes: 47 additions & 6 deletions lib/iris/src/iris/cluster/controller/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ class ReservationClaim:
"""Maximum task_resource_history rows retained per (task_id, attempt_id).
Logarithmic downsampling triggers at 2x this value."""

TASK_RESOURCE_HISTORY_TERMINAL_TTL = Duration.from_hours(1)
"""After a task reaches a terminal state, its resource history is fully
evicted this long after the finish timestamp. Dashboards surface peak
memory from tasks.peak_memory_mb once a task is done; retaining per-sample
rows forever bloats the DB (~85% of task_resource_history on prod is for
terminal tasks) and amplifies writer contention during heartbeat batches."""

TASK_RESOURCE_HISTORY_DELETE_CHUNK = 5000
"""Maximum ids per DELETE in prune_task_resource_history — bounds how long
the writer lock is held per chunk so other RPCs can interleave."""

DIRECT_PROVIDER_PROMOTION_RATE = 128
"""Token bucket capacity for task promotion (pods per minute).

Expand Down Expand Up @@ -2778,12 +2789,40 @@ def prune_worker_resource_history(self) -> int:
)

def prune_task_resource_history(self) -> int:
"""Logarithmic downsampling: when a (task, attempt) exceeds 2*N rows,
thin the older half by deleting every other row.

Over repeated compaction cycles older data becomes exponentially sparser,
preserving long-term trends while bounding total row count.
"""Two-pass prune:

1. Evict all history for tasks that have been in a terminal state
longer than TASK_RESOURCE_HISTORY_TERMINAL_TTL. Dashboards read
peak memory from tasks.peak_memory_mb after termination; the
per-sample rows are dead weight and are ~85% of the table on
prod.
2. Logarithmic downsampling for anything that remains: when a
(task, attempt) exceeds 2*N rows, thin the older half by deleting
every other row so older data grows exponentially sparser.

Deletes are chunked so the writer lock releases between chunks.
"""
now_ms = Timestamp.now().epoch_ms()
ttl_cutoff_ms = now_ms - TASK_RESOURCE_HISTORY_TERMINAL_TTL.to_ms()
terminal_placeholders = ",".join("?" for _ in TERMINAL_TASK_STATES)

evicted_terminal = 0
with self._db.transaction() as cur:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Commit each TTL delete chunk in separate transaction

prune_task_resource_history wraps the entire terminal-TTL eviction in one self._db.transaction() block, so all chunked DELETE ... IN (...) statements run under a single BEGIN IMMEDIATE write lock. Because ControllerDB.transaction() holds that lock until the context exits, the new chunking does not actually let other RPC writes interleave; with large terminal task sets this can still block heartbeat/scheduling writes for the full eviction duration and recreate multi-second contention spikes.

Useful? React with 👍 / 👎.

terminal_ids = [
str(r["task_id"])
for r in cur.execute(
f"SELECT task_id FROM tasks "
f"WHERE state IN ({terminal_placeholders}) "
f"AND finished_at_ms IS NOT NULL AND finished_at_ms < ?",
(*TERMINAL_TASK_STATES, ttl_cutoff_ms),
).fetchall()
]
for chunk_start in range(0, len(terminal_ids), TASK_RESOURCE_HISTORY_DELETE_CHUNK):
chunk = terminal_ids[chunk_start : chunk_start + TASK_RESOURCE_HISTORY_DELETE_CHUNK]
ph = ",".join("?" * len(chunk))
cur.execute(f"DELETE FROM task_resource_history WHERE task_id IN ({ph})", tuple(chunk))
evicted_terminal += cur.rowcount

threshold = TASK_RESOURCE_HISTORY_RETENTION * 2
with self._db.transaction() as cur:
overflows = cur.execute(
Expand Down Expand Up @@ -2814,9 +2853,11 @@ def prune_task_resource_history(self) -> int:
ph = ",".join("?" * len(chunk))
cur.execute(f"DELETE FROM task_resource_history WHERE id IN ({ph})", tuple(chunk))
total_deleted += cur.rowcount
if evicted_terminal > 0:
logger.info("Evicted %d task_resource_history rows (terminal TTL)", evicted_terminal)
if total_deleted > 0:
logger.info("Pruned %d task_resource_history rows (log downsampling)", total_deleted)
return total_deleted
return evicted_terminal + total_deleted

def _batch_delete(
self,
Expand Down
Loading
Loading