@@ -121,10 +121,20 @@ def clone_db(source: ControllerDB) -> ControllerDB:
121121 clone_path = clone_dir / ControllerDB .DB_FILENAME
122122 conn = sqlite3 .connect (str (clone_path ))
123123 conn .execute ("ATTACH DATABASE ? AS src" , (str (source .db_path ),))
124- # Copy schema + data for each table
125- for table in _CLONE_TABLES :
126- conn .execute (f"CREATE TABLE { table } AS SELECT * FROM src.{ table } " )
127- # Copy indexes from source schema
124+
125+ # Use the source's real CREATE TABLE DDL — CREATE TABLE AS SELECT drops
126+ # UNIQUE/PRIMARY KEY/CHECK constraints, which breaks UPSERT paths like
127+ # register_worker's INSERT ... ON CONFLICT.
128+ clone_tables = set (_CLONE_TABLES )
129+ table_ddl = conn .execute ("SELECT name, sql FROM src.sqlite_master WHERE type='table' AND sql IS NOT NULL" ).fetchall ()
130+ for name , sql in table_ddl :
131+ if name not in clone_tables :
132+ continue
133+ conn .execute (sql )
134+ conn .execute (f"INSERT INTO { name } SELECT * FROM src.{ name } " )
135+
136+ # Copy indexes from source schema (skip autoindexes — those come from
137+ # UNIQUE/PK constraints already in the CREATE TABLE).
128138 rows = conn .execute ("SELECT sql FROM src.sqlite_master WHERE type='index' AND sql IS NOT NULL" ).fetchall ()
129139 for row in rows :
130140 try :
@@ -138,6 +148,7 @@ def clone_db(source: ControllerDB) -> ControllerDB:
138148 conn .execute (row [0 ])
139149 except sqlite3 .OperationalError :
140150 pass
151+ conn .commit ()
141152 conn .execute ("DETACH DATABASE src" )
142153 conn .execute ("ANALYZE" )
143154 conn .close ()
@@ -1216,6 +1227,261 @@ def _burst_100_contended():
12161227 hb_thread .join (timeout = 10.0 )
12171228
12181229
def _build_heartbeat_requests(db: ControllerDB) -> list[HeartbeatApplyRequest]:
    """Build a heartbeat batch shaped like a live provider-sync round:
    one HeartbeatApplyRequest per active worker, with one RUNNING
    resource-usage update per task currently assigned to that worker.

    Args:
        db: Controller database to read workers/tasks from (read-only here).

    Returns:
        One request per healthy active worker; workers with no active tasks
        still get a request with an empty update list, matching a real
        heartbeat round.
    """
    workers = healthy_active_workers_with_attributes(db)
    active_states = tuple(ACTIVE_TASK_STATES)
    # Derive the IN(...) placeholder list from the actual number of active
    # states instead of hard-coding three '?' marks — a hard-coded list
    # silently breaks if ACTIVE_TASK_STATES ever changes size.
    placeholders = ", ".join("?" * len(active_states))
    # One shared snapshot/usage proto for every update: the benchmark only
    # cares about write volume, not per-task values.
    snapshot_proto = job_pb2.WorkerResourceSnapshot()
    usage = job_pb2.ResourceUsage(cpu_millicores=1000, memory_mb=1024)
    requests: list[HeartbeatApplyRequest] = []
    for w in workers:
        wid = str(w.worker_id)
        rows = db.fetchall(
            "SELECT task_id, current_attempt_id FROM tasks "
            f"WHERE current_worker_id = ? AND state IN ({placeholders})",
            (wid, *active_states),
        )
        updates = [
            TaskUpdate(
                task_id=JobName.from_wire(str(r["task_id"])),
                attempt_id=int(r["current_attempt_id"]),
                new_state=job_pb2.TASK_STATE_RUNNING,
                resource_usage=usage,
            )
            for r in rows
        ]
        requests.append(
            HeartbeatApplyRequest(
                worker_id=WorkerId(wid),
                worker_resource_snapshot=snapshot_proto,
                updates=updates,
            )
        )
    return requests
1263+
1264+
def _build_failure_batch(db: ControllerDB, n: int) -> list[tuple[DispatchBatch, str]]:
    """Return up to *n* (DispatchBatch, reason) pairs for currently-active
    workers, simulating a provider-sync failure sweep against the clone DB."""
    reason = "benchmark: simulated provider-sync failure"
    worker_rows = db.fetchall(
        "SELECT worker_id, address FROM workers WHERE active = 1 LIMIT ?",
        (n,),
    )
    batch: list[tuple[DispatchBatch, str]] = []
    for row in worker_rows:
        addr = row["address"]
        dispatch = DispatchBatch(
            worker_id=WorkerId(str(row["worker_id"])),
            worker_address=str(addr) if addr is not None else None,
            running_tasks=[],
        )
        batch.append((dispatch, reason))
    return batch
1281+
1282+
def _print_latency_distribution(name: str, latencies: list[float]) -> None:
    """Print p50/p95/p99/max for *latencies* (milliseconds) and record
    (name, p50, p95, n) into the module-level _results summary table.

    NOTE: sorts *latencies* in place.
    """
    if not latencies:
        print(f" {name:60s} (no samples)")
        return
    latencies.sort()
    n = len(latencies)
    # int(n * frac) < n for frac < 1, so these indices are always in range.
    p50 = latencies[n // 2]
    p95 = latencies[int(n * 0.95)]
    p99 = latencies[int(n * 0.99)]
    worst = latencies[-1]
    _results.append((name, p50, p95, n))
    print(
        f" {name:60s} n={len(latencies):3d} "
        f"p50={p50:7.1f}ms p95={p95:8.1f}ms p99={p99:8.1f}ms max={worst:8.1f}ms"
    )
1297+
1298+
def _run_apply_under_contention(
    *,
    name: str,
    write_db: ControllerDB,
    write_txns: ControllerTransitions,
    heartbeat_requests: list[HeartbeatApplyRequest],
    fail_threads: int = 0,
    fail_n: int = 50,
    fail_chunk: int = 50,
    fail_interval_s: float = 2.0,
    register_threads: int = 0,
    register_burst: int = 100,
    endpoint_threads: int = 0,
    checkpoint_thread: bool = False,
    synchronous_normal: bool = False,
    duration_s: float = 8.0,
) -> None:
    """Run apply_heartbeats_batch repeatedly on a victim thread while
    configurable write storms hammer the same clone DB. Report p50/p95/p99/max
    of the victim's per-call latency.

    Args:
        name: Scenario label for the printed latency line.
        write_db / write_txns: Clone database and its transitions facade; all
            threads (victim and storms) share these.
        heartbeat_requests: Pre-built victim batch (one per worker).
        fail_threads / fail_n / fail_chunk / fail_interval_s: Number of
            fail_heartbeats_batch storm threads and their batch shape/cadence.
        register_threads / register_burst: register_worker storm threads and
            how many workers each registers per burst.
        endpoint_threads: add_endpoint storm threads.
        checkpoint_thread: Also run a 1 Hz forced WAL TRUNCATE checkpoint loop.
        synchronous_normal: Set PRAGMA synchronous=NORMAL via a throwaway
            connection before the run (see inline note).
        duration_s: How long to let the storm run before stopping.
    """
    if synchronous_normal:
        # PRAGMA synchronous can't be changed mid-connection once a tx has run,
        # so issue it on a fresh raw connection to the clone file. It persists
        # for that connection only; our ControllerDB connection is unaffected,
        # which is the point — prod can't change synchronous mid-flight either.
        _raw = sqlite3.connect(str(write_db.db_path))
        _raw.execute("PRAGMA synchronous=NORMAL")
        _raw.close()

    endpoint_tasks_rows = write_db.fetchall(
        "SELECT task_id FROM tasks WHERE state IN (1,2,3,9) AND current_attempt_id IS NOT NULL LIMIT 200"
    )
    endpoint_tasks = [JobName.from_wire(str(r["task_id"])) for r in endpoint_tasks_rows]

    stop = threading.Event()
    victim_latencies: list[float] = []
    errors: list[BaseException] = []

    def _victim():
        # The measured thread: apply the same heartbeat batch in a tight loop.
        try:
            while not stop.is_set():
                t0 = time.perf_counter()
                write_txns.apply_heartbeats_batch(heartbeat_requests)
                victim_latencies.append((time.perf_counter() - t0) * 1000)
        except BaseException as e:
            errors.append(e)

    def _fail_storm():
        # Periodic bulk worker-failure writes (rebuilt each round because
        # earlier rounds shrink the active-worker set).
        try:
            while not stop.is_set():
                failures = _build_failure_batch(write_db, fail_n)
                if failures:
                    write_txns.fail_heartbeats_batch(failures, force_remove=True, chunk_size=fail_chunk)
                stop.wait(fail_interval_s)
        except BaseException as e:
            errors.append(e)

    def _register_storm():
        # Bursts of register_worker calls with unique ids per burst.
        try:
            meta = _build_sample_worker_metadata()
            while not stop.is_set():
                base = f"bench-contend-{uuid.uuid4().hex[:8]}"
                for i in range(register_burst):
                    write_txns.register_worker(
                        worker_id=WorkerId(f"{base}-{i}"),
                        address=f"tcp://{base}-{i}:1234",
                        metadata=meta,
                        ts=Timestamp.now(),
                        slice_id="",
                        scale_group="bench",
                    )
                    if stop.is_set():
                        break
        except BaseException as e:
            errors.append(e)

    def _endpoint_storm():
        # Continuous add_endpoint writes cycling through the sampled tasks.
        try:
            if not endpoint_tasks:
                # Clone has no eligible tasks — exit quietly rather than
                # dying on `i % 0` and polluting the error report.
                return
            i = 0
            while not stop.is_set():
                t = endpoint_tasks[i % len(endpoint_tasks)]
                write_txns.add_endpoint(_make_endpoint(t))
                i += 1
        except BaseException as e:
            errors.append(e)

    def _checkpoint_loop():
        # Forced WAL truncation every second; OperationalError (e.g. busy)
        # is expected under contention and intentionally ignored.
        try:
            while not stop.is_set():
                try:
                    write_db.execute("PRAGMA wal_checkpoint(TRUNCATE)")
                except sqlite3.OperationalError:
                    pass
                stop.wait(1.0)
        except BaseException as e:
            errors.append(e)

    threads: list[threading.Thread] = [threading.Thread(target=_victim, name="victim")]
    for _ in range(fail_threads):
        threads.append(threading.Thread(target=_fail_storm, name="fail"))
    for _ in range(register_threads):
        threads.append(threading.Thread(target=_register_storm, name="register"))
    for _ in range(endpoint_threads):
        threads.append(threading.Thread(target=_endpoint_storm, name="endpoint"))
    if checkpoint_thread:
        threads.append(threading.Thread(target=_checkpoint_loop, name="checkpoint"))

    for t in threads:
        t.start()
    time.sleep(duration_s)
    stop.set()
    for t in threads:
        t.join(timeout=30.0)

    if errors:
        # Only the first error is surfaced; enough to flag a broken scenario.
        print(f" {name}: background thread error: {errors[0]!r}")
    _print_latency_distribution(name, victim_latencies)
1417+
1418+
def benchmark_apply_contention(db: ControllerDB) -> None:
    """Reproduce the production 'apply results' multi-second tail by running
    apply_heartbeats_batch as the victim under concurrent write storms.
    """
    heartbeat_requests = _build_heartbeat_requests(db)
    total_tasks = sum(len(req.updates) for req in heartbeat_requests)
    print(f" (victim heartbeat batch: {len(heartbeat_requests)} workers, {total_tasks} tasks)")

    if not heartbeat_requests:
        print(" (skipped, no workers)")
        return

    # Shared knob set for the three "heavy" scenarios.
    heavy = dict(
        fail_threads=2,
        fail_chunk=200,
        fail_interval_s=0.5,
        register_threads=2,
        endpoint_threads=2,
    )
    scenarios = [
        dict(name="apply @ baseline (no contention)"),
        dict(name="apply + 1x fail_heartbeats_batch", fail_threads=1),
        dict(name="apply + 1x register_worker burst", register_threads=1),
        dict(name="apply + 1x add_endpoint storm", endpoint_threads=1),
        dict(
            name="apply + prod-mix (fail + register + endpoint)",
            fail_threads=1,
            register_threads=1,
            endpoint_threads=1,
        ),
        dict(name="apply + heavy storm (2f/2r/2e, chunk=200, 0.5s)", **heavy),
        dict(name="apply + heavy + forced WAL checkpoints", checkpoint_thread=True, **heavy),
        dict(name="apply + heavy + synchronous=NORMAL", synchronous_normal=True, **heavy),
    ]

    # All scenarios share one writable clone so results are comparable.
    write_db = clone_db(db)
    write_txns = ControllerTransitions(write_db)
    try:
        for scenario in scenarios:
            _run_apply_under_contention(
                write_db=write_db,
                write_txns=write_txns,
                heartbeat_requests=heartbeat_requests,
                **scenario,
            )
    finally:
        # Tear the clone down even if a scenario raises.
        write_db.close()
        shutil.rmtree(write_db._db_dir, ignore_errors=True)
1483+
1484+
12191485def print_summary () -> None :
12201486 print ("\n " + "=" * 80 )
12211487 print (f" { 'Query' :50s} { 'p50' :>10s} { 'p95' :>10s} { 'n' :>5s} " )
@@ -1263,7 +1529,7 @@ def _ensure_db(db_path: Path | None) -> Path:
12631529@click .option (
12641530 "--only" ,
12651531 "only_group" ,
1266- type = click .Choice (["scheduling" , "dashboard" , "heartbeat" , "endpoints" ]),
1532+ type = click .Choice (["scheduling" , "dashboard" , "heartbeat" , "endpoints" , "apply_contention" ]),
12671533 help = "Run only this group" ,
12681534)
12691535@click .option ("--no-analyze" , is_flag = True , help = "Skip ANALYZE to test unoptimized query plans" )
@@ -1309,6 +1575,11 @@ def main(db_path: Path | None, only_group: str | None, no_analyze: bool, fresh:
13091575 if only_group is None or only_group == "endpoints" :
13101576 print ("[endpoints]" )
13111577 benchmark_endpoints (db )
1578+ print ()
1579+
1580+ if only_group == "apply_contention" :
1581+ print ("[apply_contention]" )
1582+ benchmark_apply_contention (db )
13121583
13131584 print_summary ()
13141585 db .close ()
0 commit comments