Skip to content

Commit 3c9e438

Browse files
authored
fix(iris): log perf, scheduling fixes, holder task device constraints (marin-community#3369)
Log performance: batch log fetching replaces per-line streaming, substring filters run server-side, and the default fetch size is reduced for faster UI response. Scheduling performance: cap non-coscheduled tasks per job per cycle (max_tasks_per_job_per_cycle=4) to bound scheduler CPU time and prevent GIL starvation of the heartbeat thread. Reservation holder constraint fix: holder jobs were created without device-type/variant constraints because they bypass the service layer's _inject_resource_constraints(). When the reservation entry had explicit constraints (e.g. region) or no constraints at all, device constraints were missing entirely. The holder could land on any worker (e.g. a v6e-4 when it needs v5p-64), the reservation would never be satisfied, and the parent job would sit pending forever. Now constraints_from_resources() + merge_constraints() auto-inject device constraints on holder creation, matching what _worker_matches_reservation_entry already does for claim matching. Preference pass fix: holder tasks have job_id `parent/:reservation:`, but claims are keyed by the parent's wire ID. The preference pass now resolves holder tasks through their parent so they route to claimed workers after worker death + requeue. Tests: 3 new tests covering device constraint injection, wrong-device-type scheduling rejection, and preference pass parent resolution for holders.
1 parent 7bbcbb8 commit 3c9e438

33 files changed

+1213
-529
lines changed

lib/iris/src/iris/actor/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def serve_background(self, port: int | None = None) -> int:
225225
host=self._host,
226226
port=self._actual_port,
227227
log_level="error",
228+
log_config=None,
228229
timeout_keep_alive=120,
229230
)
230231
self._server = uvicorn.Server(config)

lib/iris/src/iris/client/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def fetch_task_logs(
817817
include_children: bool = False,
818818
start: Timestamp | None = None,
819819
max_lines: int = 0,
820-
regex: str | None = None,
820+
substring: str | None = None,
821821
attempt_id: int = -1,
822822
min_level: str = "",
823823
) -> list[TaskLogEntry]:
@@ -828,7 +828,7 @@ def fetch_task_logs(
828828
include_children: Include logs from child jobs (job ID only)
829829
start: Only return logs after this timestamp (None = from beginning)
830830
max_lines: Maximum number of log lines to return (0 = unlimited)
831-
regex: Regex filter for log content
831+
substring: Substring filter for log content
832832
attempt_id: Filter to specific attempt (-1 = all attempts)
833833
min_level: Minimum log level filter (DEBUG/INFO/WARNING/ERROR/CRITICAL)
834834
@@ -840,7 +840,7 @@ def fetch_task_logs(
840840
include_children=include_children,
841841
since_ms=start.epoch_ms() if start else 0,
842842
max_total_lines=max_lines,
843-
regex=regex,
843+
substring=substring,
844844
attempt_id=attempt_id,
845845
min_level=min_level,
846846
)

lib/iris/src/iris/cluster/client/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def fetch_task_logs(
104104
include_children: bool = False,
105105
since_ms: int = 0,
106106
max_total_lines: int = 0,
107-
regex: str | None = None,
107+
substring: str | None = None,
108108
attempt_id: int = -1,
109109
cursor: int = 0,
110110
min_level: str = "",

lib/iris/src/iris/cluster/client/remote_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def fetch_task_logs(
340340
include_children: bool = False,
341341
since_ms: int = 0,
342342
max_total_lines: int = 0,
343-
regex: str | None = None,
343+
substring: str | None = None,
344344
attempt_id: int = -1,
345345
cursor: int = 0,
346346
min_level: str = "",
@@ -352,7 +352,7 @@ def fetch_task_logs(
352352
include_children: Include logs from child jobs (job ID only)
353353
since_ms: Only return logs after this timestamp (exclusive)
354354
max_total_lines: Maximum total lines (0 = default 10000)
355-
regex: Regex filter for log content
355+
substring: Substring filter for log content
356356
attempt_id: Filter to specific attempt (-1 = all attempts)
357357
cursor: Autoincrement id cursor for incremental polling
358358
min_level: Minimum log level filter (DEBUG/INFO/WARNING/ERROR/CRITICAL)
@@ -362,7 +362,7 @@ def fetch_task_logs(
362362
include_children=include_children,
363363
since_ms=since_ms,
364364
max_total_lines=max_total_lines,
365-
regex=regex or "",
365+
substring=substring or "",
366366
attempt_id=attempt_id,
367367
cursor=cursor,
368368
min_level=min_level,

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 72 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,18 @@
1717
import uvicorn
1818

1919
from iris.chaos import chaos
20+
from iris.cluster.constraints import (
21+
AttributeValue,
22+
Constraint,
23+
PlacementRequirements,
24+
WellKnownAttribute,
25+
constraints_from_resources,
26+
evaluate_constraint,
27+
extract_placement_requirements,
28+
merge_constraints,
29+
)
2030
from iris.cluster.controller.autoscaler import Autoscaler, DemandEntry
2131
from iris.cluster.controller.dashboard import ControllerDashboard
22-
from iris.cluster.log_store import LogStoreHandler, PROCESS_LOG_KEY
2332
from iris.cluster.controller.events import TaskAssignedEvent, TaskStateChangedEvent
2433
from iris.cluster.controller.scheduler import (
2534
JobRequirements,
@@ -28,40 +37,32 @@
2837
WorkerSnapshot,
2938
)
3039
from iris.cluster.controller.service import ControllerServiceImpl
40+
from iris.cluster.controller.snapshot import (
41+
SnapshotResult,
42+
create_snapshot,
43+
read_latest_snapshot,
44+
restore_scaling_group,
45+
restore_snapshot,
46+
restore_tracked_workers,
47+
write_snapshot,
48+
)
3149
from iris.cluster.controller.state import (
3250
HEARTBEAT_FAILURE_THRESHOLD,
51+
RESERVATION_HOLDER_JOB_NAME,
3352
ControllerJob,
3453
ControllerState,
3554
ControllerTask,
3655
ControllerWorker,
3756
HeartbeatSnapshot,
3857
ReservationClaim,
3958
)
40-
from iris.cluster.constraints import (
41-
AttributeValue,
42-
Constraint,
43-
PlacementRequirements,
44-
WellKnownAttribute,
45-
constraints_from_resources,
46-
evaluate_constraint,
47-
extract_placement_requirements,
48-
merge_constraints,
49-
)
59+
from iris.cluster.log_store import PROCESS_LOG_KEY, LogStoreHandler
5060
from iris.cluster.types import (
5161
JobName,
5262
VmWorkerStatus,
5363
VmWorkerStatusMap,
5464
WorkerId,
5565
)
56-
from iris.cluster.controller.snapshot import (
57-
SnapshotResult,
58-
create_snapshot,
59-
read_latest_snapshot,
60-
restore_scaling_group,
61-
restore_snapshot,
62-
restore_tracked_workers,
63-
write_snapshot,
64-
)
6566
from iris.logging import slow_log
6667
from iris.managed_thread import ManagedThread, ThreadContainer, get_thread_container
6768
from iris.rpc import cluster_pb2, snapshot_pb2
@@ -145,13 +146,15 @@ def compute_demand_entries(
145146
# Also track which jobs have reservations so we can apply taint injection.
146147
jobs: dict[JobName, JobRequirements] = {}
147148
has_reservation: set[JobName] = set()
149+
has_direct_reservation: set[JobName] = set()
148150
for task in all_schedulable:
149151
if task.job_id not in jobs:
150152
job = state.get_job(task.job_id)
151153
if job:
152154
jobs[task.job_id] = job_requirements_from_job(job)
153155
if job.request.HasField("reservation"):
154156
has_reservation.add(task.job_id)
157+
has_direct_reservation.add(task.job_id)
155158
elif _find_reservation_ancestor(state, task.job_id) is not None:
156159
has_reservation.add(task.job_id)
157160

@@ -163,7 +166,7 @@ def compute_demand_entries(
163166
task_ids = [t.task_id for t in all_schedulable]
164167
claims = reservation_claims or {}
165168
dry_run_workers = _inject_reservation_taints(workers, claims)
166-
dry_run_jobs = _inject_taint_constraints(jobs, has_reservation)
169+
dry_run_jobs = _inject_taint_constraints(jobs, has_reservation, has_direct_reservation)
167170

168171
context = scheduler.create_scheduling_context(
169172
dry_run_workers,
@@ -289,24 +292,42 @@ def _inject_reservation_taints(
289292
def _inject_taint_constraints(
290293
jobs: dict[JobName, JobRequirements],
291294
has_reservation: set[JobName],
295+
has_direct_reservation: set[JobName] | None = None,
292296
) -> dict[JobName, JobRequirements]:
293-
"""Add NOT_EXISTS reservation-job constraint to non-reservation jobs.
294-
295-
This prevents normal jobs from being scheduled onto claimed workers.
296-
Reservation jobs are left unchanged — they can use both claimed and
297-
unclaimed workers (the reservation is a floor, not a ceiling).
297+
"""Add reservation taint constraints to jobs.
298+
299+
Three-way logic:
300+
- Direct reservation jobs (has_direct_reservation): get an EQ constraint
301+
forcing them onto their claimed workers only.
302+
- Descendants of reservation jobs (has_reservation minus direct): no
303+
constraint — they can use both claimed and unclaimed workers.
304+
- Non-reservation jobs: get a NOT_EXISTS constraint blocking them from
305+
claimed workers.
298306
"""
299307
if not has_reservation and not jobs:
300308
return jobs
301309

310+
if has_direct_reservation is None:
311+
has_direct_reservation = set()
312+
302313
taint_constraint = cluster_pb2.Constraint(
303314
key=RESERVATION_TAINT_KEY,
304315
op=cluster_pb2.CONSTRAINT_OP_NOT_EXISTS,
305316
)
306317

307318
modified: dict[JobName, JobRequirements] = {}
308319
for job_id, req in jobs.items():
309-
if job_id in has_reservation:
320+
if job_id in has_direct_reservation:
321+
eq_constraint = cluster_pb2.Constraint(
322+
key=RESERVATION_TAINT_KEY,
323+
op=cluster_pb2.CONSTRAINT_OP_EQ,
324+
value=cluster_pb2.AttributeValue(string_value=job_id.to_wire()),
325+
)
326+
modified[job_id] = replace(
327+
req,
328+
constraints=[*list(req.constraints), eq_constraint],
329+
)
330+
elif job_id in has_reservation:
310331
modified[job_id] = req
311332
else:
312333
modified[job_id] = replace(
@@ -417,7 +438,14 @@ def _preference_pass(
417438
continue
418439

419440
job_wire = job_id.to_wire()
420-
for wid in claimed_by_job.get(job_wire, ()):
441+
# Holder jobs are children of the reservation job — look up claims
442+
# under the parent's wire ID.
443+
claim_key = job_wire
444+
if RESERVATION_HOLDER_JOB_NAME in job_wire:
445+
parent = job_id.parent
446+
if parent is not None:
447+
claim_key = parent.to_wire()
448+
for wid in claimed_by_job.get(claim_key, ()):
421449
if context.assignment_counts.get(wid, 0) >= context.max_assignments_per_worker:
422450
continue
423451
capacity = context.capacities.get(wid)
@@ -494,6 +522,12 @@ class ControllerConfig:
494522
max_dispatch_parallelism: int = 32
495523
"""Maximum number of concurrent RPC dispatch operations."""
496524

525+
max_tasks_per_job_per_cycle: int = 4
526+
"""Maximum tasks from a single non-coscheduled job to consider per scheduling
527+
cycle. Bounds CPU time in the scheduler when many tasks are pending, preventing
528+
GIL starvation of the heartbeat thread. Coscheduled jobs are exempt (they need
529+
all tasks for atomic assignment). Set to 0 for unlimited."""
530+
497531
heartbeat_failure_threshold: int = HEARTBEAT_FAILURE_THRESHOLD
498532
"""Consecutive heartbeat failures before marking worker as dead."""
499533

@@ -656,6 +690,7 @@ def start(self) -> None:
656690
host=self._config.host,
657691
port=self._config.port,
658692
log_level="warning",
693+
log_config=None,
659694
timeout_keep_alive=120,
660695
)
661696
self._server = uvicorn.Server(server_config)
@@ -847,9 +882,13 @@ def _run_scheduling(self) -> None:
847882

848883
# Handle timeouts and reservation gates before scheduling.
849884
# Holder tasks participate in scheduling like normal tasks.
885+
# Cap non-coscheduled tasks per job to bound scheduling CPU time.
850886
schedulable_task_ids: list[JobName] = []
851887
jobs: dict[JobName, JobRequirements] = {}
852888
has_reservation: set[JobName] = set()
889+
has_direct_reservation: set[JobName] = set()
890+
tasks_per_job: dict[JobName, int] = defaultdict(int)
891+
cap = self._config.max_tasks_per_job_per_cycle
853892
for task in pending_tasks:
854893
if not task.can_be_scheduled():
855894
continue
@@ -863,11 +902,15 @@ def _run_scheduling(self) -> None:
863902
# Holder tasks are always schedulable (they ARE the reservation).
864903
if not job.is_reservation_holder and not self._is_reservation_satisfied(job):
865904
continue
905+
if cap > 0 and not job.is_coscheduled and tasks_per_job[task.job_id] >= cap:
906+
continue
907+
tasks_per_job[task.job_id] += 1
866908
schedulable_task_ids.append(task.task_id)
867909
if task.job_id not in jobs:
868910
jobs[task.job_id] = job_requirements_from_job(job)
869911
if job.request.HasField("reservation"):
870912
has_reservation.add(task.job_id)
913+
has_direct_reservation.add(task.job_id)
871914
elif _find_reservation_ancestor(self._state, task.job_id) is not None:
872915
has_reservation.add(task.job_id)
873916

@@ -877,7 +920,7 @@ def _run_scheduling(self) -> None:
877920
# Inject reservation taints: claimed workers get a taint attribute,
878921
# non-reservation jobs get a NOT_EXISTS constraint for it.
879922
modified_workers = _inject_reservation_taints(workers, self._reservation_claims)
880-
jobs = _inject_taint_constraints(jobs, has_reservation)
923+
jobs = _inject_taint_constraints(jobs, has_reservation, has_direct_reservation)
881924

882925
with slow_log(logger, "snapshot_building_counts", threshold_ms=50):
883926
building_counts = self._state.snapshot_building_counts()

lib/iris/src/iris/cluster/controller/scaling_group.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,19 @@ def _lifecycle_to_vm_state(lifecycle: SliceLifecycleState) -> vm_pb2.VmState:
173173
}[lifecycle]
174174

175175

176-
def slice_state_to_proto(state: SliceState) -> vm_pb2.SliceInfo:
176+
def slice_state_to_proto(state: SliceState, idle_threshold: Duration | None = None) -> vm_pb2.SliceInfo:
177177
"""Convert a SliceState to a SliceInfo proto for RPC APIs."""
178178
created_at = state.handle.created_at
179179
vm_state = _lifecycle_to_vm_state(state.lifecycle)
180+
181+
is_idle = False
182+
if idle_threshold is not None and state.lifecycle == SliceLifecycleState.READY:
183+
if state.last_active.epoch_ms() == 0:
184+
is_idle = True
185+
else:
186+
idle_duration = Duration.from_ms(Timestamp.now().epoch_ms() - state.last_active.epoch_ms())
187+
is_idle = idle_duration >= idle_threshold
188+
180189
return vm_pb2.SliceInfo(
181190
slice_id=state.handle.slice_id,
182191
scale_group=state.handle.scale_group,
@@ -192,6 +201,8 @@ def slice_state_to_proto(state: SliceState) -> vm_pb2.SliceInfo:
192201
for i, addr in enumerate(state.vm_addresses)
193202
],
194203
error_message=state.error_message,
204+
last_active=state.last_active.to_proto(),
205+
idle=is_idle,
195206
)
196207

197208

@@ -997,7 +1008,8 @@ def to_status(self) -> vm_pb2.ScaleGroupStatus:
9971008
availability_reason=availability.reason,
9981009
blocked_until=blocked_until.to_proto(),
9991010
scale_up_cooldown_until=cooldown_until.to_proto(),
1000-
slices=[slice_state_to_proto(state) for state in snapshot],
1011+
slices=[slice_state_to_proto(state, idle_threshold=self._idle_threshold) for state in snapshot],
1012+
idle_threshold_ms=self._idle_threshold.to_ms(),
10011013
)
10021014
for state_name, count in counts.items():
10031015
status.slice_state_counts[state_name] = count

0 commit comments

Comments (0)