Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,6 @@ test:
export HF_HUB_TOKEN=$HF_TOKEN
RAY_ADDRESS= PYTHONPATH=tests:. pytest tests --durations=0 -n 4 --tb=no -v

# Target to configure GCP registry cleanup policy for all standard regions
CLUSTER_REPOS = us-central2 us-central1 europe-west4 us-west4 us-east5 us-east1
default_registry_name = marin
configure_gcp_registry_all:
@echo "Configuring GCP registry cleanup policy for all standard regions..."
$(foreach region,$(CLUSTER_REPOS), \
python infra/configure_gcp_registry.py $(default_registry_name) --region=$(region) ; \
)
@echo "Cleanup policy configured for all regions."


# stuff for setting up locally
install_uv:
Expand Down
9 changes: 8 additions & 1 deletion lib/iris/src/iris/cli/bug_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from iris.cluster.types import JobName
from iris.rpc import controller_pb2, job_pb2
from iris.rpc.auth import AuthTokenInjector, TokenProvider
from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
from iris.rpc.controller_connect import ControllerServiceClientSync
from iris.rpc.proto_utils import format_resources, job_state_friendly, task_state_friendly
from iris.time_proto import timestamp_from_proto
Expand Down Expand Up @@ -119,7 +120,13 @@ def gather_bug_report(
) -> BugReport:
"""Gather all diagnostic data for a job into a BugReport."""
interceptors = [AuthTokenInjector(token_provider)] if token_provider else []
client = ControllerServiceClientSync(controller_url, timeout_ms=30000, interceptors=interceptors)
client = ControllerServiceClientSync(
controller_url,
timeout_ms=30000,
interceptors=interceptors,
accept_compression=IRIS_RPC_COMPRESSIONS,
send_compression=None,
)
log_client = LogClient.connect(controller_url, timeout_ms=30000, interceptors=interceptors)
try:
return _gather(client, log_client, job_id, tail=tail)
Expand Down
14 changes: 4 additions & 10 deletions lib/iris/src/iris/cli/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from iris.rpc.auth import TokenProvider
from iris.rpc.proto_utils import (
PRIORITY_BAND_NAMES,
format_resources,
job_state_friendly,
priority_band_value,
task_state_friendly,
Expand Down Expand Up @@ -1008,31 +1007,26 @@ def list_jobs(ctx, state: str | None, prefix: str | None, json_output: bool) ->
click.echo("No jobs found.")
return

# Build table rows
rows: list[list[str]] = []
has_reasons = False

for j in jobs:
job_id = j.job_id
state_name = job_state_friendly(j.state)
submitted = timestamp_from_proto(j.submitted_at).as_formatted_date() if j.submitted_at.epoch_ms else "-"
resources = format_resources(j.resources) if j.HasField("resources") else "-"

# Show error for failed jobs, pending_reason for pending/unschedulable
reason = j.error or j.pending_reason or ""
if reason:
has_reasons = True
# Truncate long reasons
reason = (reason[:60] + "...") if len(reason) > 63 else reason

rows.append([job_id, state_name, resources, submitted, reason])
rows.append([job_id, state_name, submitted, reason])

# Build headers - only include REASON column if there are any reasons
if has_reasons:
headers = ["JOB ID", "STATE", "RESOURCES", "SUBMITTED", "REASON"]
headers = ["JOB ID", "STATE", "SUBMITTED", "REASON"]
else:
headers = ["JOB ID", "STATE", "RESOURCES", "SUBMITTED"]
rows = [row[:4] for row in rows]
headers = ["JOB ID", "STATE", "SUBMITTED"]
rows = [row[:3] for row in rows]

click.echo(tabulate(rows, headers=headers, tablefmt="plain"))

Expand Down
9 changes: 8 additions & 1 deletion lib/iris/src/iris/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from iris.rpc import config_pb2, job_pb2
from iris.rpc import controller_pb2 as _controller_pb2
from iris.rpc.auth import AuthTokenInjector, GcpAccessTokenProvider, StaticTokenProvider, TokenProvider
from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
from iris.rpc.controller_connect import ControllerServiceClientSync
from iris.rpc.proto_utils import PRIORITY_BAND_NAMES, priority_band_name, priority_band_value

Expand Down Expand Up @@ -124,7 +125,13 @@ def rpc_client(
) -> ControllerServiceClientSync:
"""Create an RPC client with optional auth. Use as a context manager: ``with rpc_client(url) as c:``."""
interceptors = [AuthTokenInjector(token_provider)] if token_provider else []
return ControllerServiceClientSync(address, timeout_ms=timeout_ms, interceptors=interceptors)
return ControllerServiceClientSync(
address,
timeout_ms=timeout_ms,
interceptors=interceptors,
accept_compression=IRIS_RPC_COMPRESSIONS,
send_compression=None,
)


def require_controller_url(ctx: click.Context) -> str:
Expand Down
3 changes: 3 additions & 0 deletions lib/iris/src/iris/client/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from iris.actor.resolver import ResolvedEndpoint, ResolveResult
from iris.cluster.types import Namespace
from iris.rpc import controller_pb2
from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
from iris.rpc.controller_connect import ControllerServiceClientSync


Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
self._client = ControllerServiceClientSync(
address=self._address,
timeout_ms=int(timeout * 1000),
accept_compression=IRIS_RPC_COMPRESSIONS,
send_compression=None,
)

def _namespace_prefix(self) -> str:
Expand Down
11 changes: 6 additions & 5 deletions lib/iris/src/iris/cluster/client/remote_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from iris.cluster.runtime.entrypoint import build_runtime_entrypoint
from iris.cluster.types import Entrypoint, EnvironmentSpec, JobName, TaskAttempt, adjust_tpu_replicas, is_job_finished
from iris.rpc import controller_pb2, job_pb2
from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
from iris.rpc.controller_connect import ControllerServiceClientSync
from iris.rpc.errors import call_with_retry, format_connect_error, poll_with_retries
from iris.time_proto import duration_to_proto
Expand All @@ -35,14 +36,12 @@

# Upper bound on GetJobState polling cadence for long-running jobs. The loop
# ramps 100ms -> 1s within a handful of polls (factor=1.5 in ExponentialBackoff)
# and then caps here, so long jobs cost ~1 state RPC / 30s instead of hammering
# the controller at the old ~2s ceiling.
# and then caps here, so long jobs cost ~1 state RPC / 30s.
MAX_STATE_POLL_INTERVAL = 30.0

# Floor on the backoff cap. ``ExponentialBackoff`` requires ``maximum >= initial``
# (currently 100ms), so we clamp the caller-supplied ``poll_interval`` up to this
# value before handing it to the backoff. Callers asking for a sub-100ms cap end
# up polling at 100ms instead of crashing with ValueError.
# (currently 100ms), so a caller-supplied cap below 100ms is clamped up to this
# value before being handed to the backoff.
MIN_STATE_POLL_INTERVAL = 0.1


Expand Down Expand Up @@ -78,6 +77,8 @@ def __init__(
address=controller_address,
timeout_ms=timeout_ms,
interceptors=interceptors,
accept_compression=IRIS_RPC_COMPRESSIONS,
send_compression=None,
)
self._log_client = LogClient.connect(
controller_address,
Expand Down
4 changes: 3 additions & 1 deletion lib/iris/src/iris/cluster/controller/autoscaler/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def load_autoscaler_checkpoint(db: ControllerDB) -> AutoscalerCheckpoint:
"last_active_ms": decode_timestamp_ms,
},
)
# Failed workers have their DB row deleted (WorkerStore.remove), so
# surviving rows with a slice are by definition the live tracked set.
tracked_rows = snapshot.raw(
"SELECT worker_id, slice_id, scale_group, address FROM workers WHERE slice_id != '' AND active = 1",
"SELECT worker_id, slice_id, scale_group, address FROM workers WHERE slice_id != ''",
)

slices_by_group: dict[str, list[SliceSnapshot]] = {}
Expand Down
Loading
Loading