Skip to content

Commit 936da9e

Browse files
committed
[iris] Push IrisClient.list_jobs filters into server-side JobQuery
Previously IrisClient.list_jobs() fetched every job in the cluster (paging offset=0, 500, 1000, ... through the full table) and filtered client-side, so `iris job list --state running` scanned the entire jobs table on every invocation. Translate `state` and `prefix` into JobQuery.state_filter and JobQuery.name_filter so the controller prunes in SQL. Prefix is re-checked client-side because name_filter is a substring, not an anchored prefix. Collapses `states: list` to `state:` since no caller passed more than one.
1 parent 28f2170 commit 936da9e

4 files changed

Lines changed: 144 additions & 17 deletions

File tree

lib/iris/src/iris/cli/job.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,16 +953,16 @@ def list_jobs(ctx, state: str | None, prefix: str | None, json_output: bool) ->
953953
controller_url = require_controller_url(ctx)
954954
client = IrisClient.remote(controller_url, workspace=Path.cwd(), token_provider=ctx.obj.get("token_provider"))
955955

956-
states: list[job_pb2.JobState] | None = None
956+
state_value: job_pb2.JobState | None = None
957957
if state is not None:
958958
state_lower = state.lower()
959959
if state_lower not in _STATE_MAP:
960960
valid = ", ".join(sorted(_STATE_MAP.keys()))
961961
raise click.UsageError(f"Unknown state '{state}'. Valid states: {valid}")
962-
states = [_STATE_MAP[state_lower]]
962+
state_value = _STATE_MAP[state_lower]
963963

964964
prefix_name = JobName.from_wire(prefix) if prefix else None
965-
jobs = client.list_jobs(states=states, prefix=prefix_name)
965+
jobs = client.list_jobs(state=state_value, prefix=prefix_name)
966966

967967
# Sort by submitted_at descending (most recent first)
968968
jobs.sort(key=lambda j: j.submitted_at.epoch_ms, reverse=True)

lib/iris/src/iris/client/client.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@
5151
TaskAttempt,
5252
adjust_tpu_replicas,
5353
)
54-
from iris.rpc import job_pb2
54+
from iris.rpc import controller_pb2, job_pb2
55+
from iris.rpc.proto_utils import job_state_friendly
5556
from iris.time_proto import timestamp_from_proto
5657
from rigging.timing import Duration, Timestamp
5758

@@ -724,28 +725,36 @@ def terminate(self, job_id: JobName) -> None:
724725
def list_jobs(
725726
self,
726727
*,
727-
states: list[job_pb2.JobState] | None = None,
728+
state: job_pb2.JobState | None = None,
728729
prefix: JobName | None = None,
729730
) -> list[job_pb2.JobStatus]:
730731
"""List jobs with optional filtering.
731732
733+
Filters are pushed down to the server via ``JobQuery`` so the
734+
controller does not page-walk its entire jobs table: ``state`` becomes
735+
``state_filter`` and ``prefix`` becomes a ``name_filter`` substring
736+
match. The prefix is re-validated client-side because ``name_filter``
737+
is a substring, not an anchored prefix.
738+
732739
Args:
733-
states: If provided, only return jobs in these states
740+
state: If provided, only return jobs in this state
734741
prefix: If provided, only return jobs whose JobName starts with this prefix
735742
736743
Returns:
737744
List of JobStatus matching the filters
738745
"""
739-
all_jobs = self._cluster_client.list_jobs()
740-
result = []
741-
for job in all_jobs:
742-
if states is not None and job.state not in states:
743-
continue
744-
job_name = JobName.from_wire(job.job_id)
745-
if prefix is not None and not job_name.to_wire().startswith(prefix.to_wire()):
746-
continue
747-
result.append(job)
748-
return result
746+
query = controller_pb2.Controller.JobQuery()
747+
if state is not None:
748+
query.state_filter = job_state_friendly(state)
749+
if prefix is not None:
750+
query.name_filter = prefix.to_wire()
751+
752+
all_jobs = self._cluster_client.list_jobs(query=query)
753+
if prefix is None:
754+
return list(all_jobs)
755+
756+
prefix_wire = prefix.to_wire()
757+
return [job for job in all_jobs if JobName.from_wire(job.job_id).to_wire().startswith(prefix_wire)]
749758

750759
def terminate_prefix(
751760
self,

lib/iris/src/iris/cluster/client/protocol.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,12 @@ def list_endpoints(self, prefix: str, *, exact: bool = False) -> list[controller
8080

8181
def list_workers(self) -> list[controller_pb2.Controller.WorkerHealthStatus]: ...
8282

83-
def list_jobs(self) -> list[job_pb2.JobStatus]: ...
83+
def list_jobs(
84+
self,
85+
*,
86+
query: controller_pb2.Controller.JobQuery | None = None,
87+
page_size: int = 500,
88+
) -> list[job_pb2.JobStatus]: ...
8489

8590
def get_task_status(self, task_name: JobName) -> job_pb2.TaskStatus: ...
8691

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Verify IrisClient.list_jobs pushes filters into the server JobQuery.
5+
6+
Server-side filtering matters: without it, a simple
7+
``iris job list --state running`` triggers a full-table paginated scan
8+
(offset=0, 500, 1000, ...) which is expensive on large deployments.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from dataclasses import dataclass, field
14+
from typing import Any
15+
16+
import pytest
17+
18+
from iris.client import IrisClient
19+
from iris.cluster.types import JobName
20+
from iris.rpc import controller_pb2, job_pb2
21+
22+
23+
@dataclass
24+
class _RecordingClusterClient:
25+
"""Captures the JobQuery passed into list_jobs; returns canned jobs."""
26+
27+
jobs: list[job_pb2.JobStatus] = field(default_factory=list)
28+
captured_queries: list[controller_pb2.Controller.JobQuery] = field(default_factory=list)
29+
30+
def list_jobs(
31+
self,
32+
*,
33+
query: controller_pb2.Controller.JobQuery | None = None,
34+
page_size: int = 500,
35+
) -> list[job_pb2.JobStatus]:
36+
del page_size
37+
captured = controller_pb2.Controller.JobQuery()
38+
if query is not None:
39+
captured.CopyFrom(query)
40+
self.captured_queries.append(captured)
41+
return list(self.jobs)
42+
43+
def shutdown(self, wait: bool = True) -> None:
44+
del wait
45+
46+
def __getattr__(self, name: str) -> Any:
47+
raise AttributeError(name)
48+
49+
50+
def _make_status(job_id: str, state: job_pb2.JobState) -> job_pb2.JobStatus:
51+
status = job_pb2.JobStatus()
52+
status.job_id = job_id
53+
status.state = state
54+
return status
55+
56+
57+
@pytest.fixture
58+
def stub_client() -> tuple[IrisClient, _RecordingClusterClient]:
59+
stub = _RecordingClusterClient()
60+
client = IrisClient(cluster=stub)
61+
return client, stub
62+
63+
64+
def test_list_jobs_no_filter_sends_empty_query(stub_client):
65+
client, stub = stub_client
66+
stub.jobs = [_make_status("/u/a", job_pb2.JOB_STATE_RUNNING)]
67+
68+
client.list_jobs()
69+
70+
assert len(stub.captured_queries) == 1
71+
q = stub.captured_queries[0]
72+
assert q.state_filter == ""
73+
assert q.name_filter == ""
74+
75+
76+
def test_list_jobs_state_is_pushed_down(stub_client):
77+
client, stub = stub_client
78+
stub.jobs = [_make_status("/u/a", job_pb2.JOB_STATE_RUNNING)]
79+
80+
client.list_jobs(state=job_pb2.JOB_STATE_RUNNING)
81+
82+
assert stub.captured_queries[0].state_filter == "running"
83+
84+
85+
def test_list_jobs_prefix_is_pushed_down(stub_client):
86+
client, stub = stub_client
87+
prefix = JobName.root("alice", "exp")
88+
stub.jobs = [_make_status(prefix.to_wire() + "-1", job_pb2.JOB_STATE_PENDING)]
89+
90+
client.list_jobs(prefix=prefix)
91+
92+
assert stub.captured_queries[0].name_filter == prefix.to_wire()
93+
94+
95+
def test_list_jobs_reanchors_prefix_client_side(stub_client):
96+
"""name_filter is a substring; client must still enforce startswith.
97+
98+
`/bob/alice/exp-oops` contains `/alice/exp` as a substring so it
99+
slips through the server-side LIKE '%...%' filter. The client must
100+
drop it because it is not a true prefix of `/alice/exp`.
101+
"""
102+
client, stub = stub_client
103+
prefix = JobName.root("alice", "exp")
104+
wire = prefix.to_wire()
105+
stub.jobs = [
106+
_make_status(wire, job_pb2.JOB_STATE_RUNNING),
107+
_make_status(wire + "-child", job_pb2.JOB_STATE_RUNNING),
108+
_make_status("/bob/alice/exp-oops", job_pb2.JOB_STATE_RUNNING),
109+
]
110+
111+
result = client.list_jobs(prefix=prefix)
112+
113+
assert {j.job_id for j in result} == {wire, wire + "-child"}

0 commit comments

Comments
 (0)