Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
516 changes: 18 additions & 498 deletions .pyrefly-baseline.json

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion lib/fray/src/fray/v1/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def create_cluster(cluster_spec: str) -> Cluster:
"JobRequest",
"LocalCluster",
"ResourceConfig",
"TPUConfig",
"TpuConfig",
"TpuType",
"create_cluster",
Expand Down
8 changes: 4 additions & 4 deletions lib/fray/src/fray/v1/cluster/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Literal, NewType, Self
from typing import Any, Literal, NewType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -436,11 +436,11 @@ def from_callable(
c: Callable[..., Any],
args: Sequence[Any] = (),
kwargs: dict[str, Any] | None = None,
) -> Self:
) -> Entrypoint:
return Entrypoint(callable_entrypoint=CallableEntrypoint(callable=c, args=args, kwargs=kwargs or {}))

@staticmethod
def from_binary(command: str, args: Sequence[str]) -> Self:
def from_binary(command: str, args: Sequence[str]) -> Entrypoint:
return Entrypoint(binary_entrypoint=BinaryEntrypoint(command=command, args=args))


Expand Down Expand Up @@ -479,7 +479,7 @@ class JobStatus(StrEnum):
STOPPED = "stopped"

@staticmethod
def finished(status: Self) -> bool:
def finished(status: JobStatus) -> bool:
return status in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.STOPPED)


Expand Down
3 changes: 0 additions & 3 deletions lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@
__all__ = [
"HEALTH_CHECK_TIMEOUT",
"START_ACTOR_TIMEOUT",
"TPU_CONFIGS",
"MultisliceInfo",
"ResourcePoolManager",
"SliceActor",
"SliceInfo",
"SlicePoolManager",
"TPUConfig",
"TPUHostActor",
"TPUHostInfo",
"TpuCancelled",
Expand All @@ -42,7 +40,6 @@
"TpuRunError",
"TpuSuccess",
"get_current_tpu_is_preempted",
"get_tpu_config",
"run_on_pod",
"run_on_pod_multislice",
"run_on_pod_ray",
Expand Down
25 changes: 12 additions & 13 deletions lib/fray/src/fray/v1/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Generic, Protocol, TypeVar

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


@dataclass
Expand All @@ -29,39 +28,39 @@ class Lease(Generic[T]):
timestamp: float


class Queue(Protocol[T_co]):
class Queue(Protocol[T]):
"""Distributed queue interface with lease-based task acquisition."""

def push(self, item: T_co) -> None:
def push(self, item: T) -> None:
"""Add an item to the queue."""
...

def peek(self) -> T_co | None:
def peek(self) -> T | None:
"""View the next available item without acquiring a lease."""
...

def pop(self, lease_timeout: float = 60.0) -> Lease[T_co] | None:
def pop(self, lease_timeout: float = 60.0) -> Lease[T] | None:
"""Acquire a lease on the next available item."""
...

def done(self, lease: Lease[T_co]) -> None:
def done(self, lease: Lease[T]) -> None:
"""Mark a leased task as successfully completed."""
...

def release(self, lease: Lease[T_co]) -> None:
def release(self, lease: Lease[T]) -> None:
"""Release a lease and requeue the item for reprocessing."""
...


class MemoryQueue(Queue[T_co]):
class MemoryQueue(Queue[T]):
def __init__(self):
self.queue = []
self.leases = {} # lease_id -> (item, timestamp, timeout)

def push(self, item: T_co) -> None:
def push(self, item: T) -> None:
self.queue.append(item)

def peek(self) -> T_co | None:
def peek(self) -> T | None:
self._recover_expired_leases()
if self.queue:
return self.queue[0]
Expand All @@ -80,7 +79,7 @@ def _recover_expired_leases(self) -> None:
self.queue.insert(0, item)
del self.leases[lease_id]

def pop(self, lease_timeout: float = 60.0) -> Lease[T_co] | None:
def pop(self, lease_timeout: float = 60.0) -> Lease[T] | None:
self._recover_expired_leases()
if self.queue:
item = self.queue.pop(0)
Expand All @@ -91,11 +90,11 @@ def pop(self, lease_timeout: float = 60.0) -> Lease[T_co] | None:
return lease
return None

def done(self, lease: Lease[T_co]) -> None:
def done(self, lease: Lease[T]) -> None:
if lease.lease_id in self.leases:
del self.leases[lease.lease_id]

def release(self, lease: Lease[T_co]) -> None:
def release(self, lease: Lease[T]) -> None:
if lease.lease_id in self.leases:
item, _, _ = self.leases[lease.lease_id]
self.queue.insert(0, item)
Expand Down
6 changes: 3 additions & 3 deletions lib/fray/src/fray/v2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Literal, Self
from typing import Any, Literal

# ---------------------------------------------------------------------------
# TPU topology
Expand Down Expand Up @@ -509,11 +509,11 @@ def from_callable(
c: Callable[..., Any],
args: Sequence[Any] = (),
kwargs: dict[str, Any] | None = None,
) -> Self:
) -> Entrypoint:
return Entrypoint(callable_entrypoint=CallableEntrypoint(callable=c, args=args, kwargs=kwargs or {}))

@staticmethod
def from_binary(command: str, args: Sequence[str]) -> Self:
def from_binary(command: str, args: Sequence[str]) -> Entrypoint:
return Entrypoint(binary_entrypoint=BinaryEntrypoint(command=command, args=args))


Expand Down
3 changes: 2 additions & 1 deletion lib/haliax/src/haliax/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ def _deshape(x):


def to_jax_shape(shape):
from haliax.core import Axis, ensure_tuple
from haliax.core import Axis
from haliax.util import ensure_tuple

shape = ensure_tuple(shape)
return tuple(axis.size if isinstance(axis, Axis) else axis for axis in shape)
3 changes: 2 additions & 1 deletion lib/iris/src/iris/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.protobuf.json_format import MessageToDict, ParseDict

from iris.cluster.constraints import WellKnownAttribute
from iris.cluster.controller.db import ControllerDB
from iris.cluster.providers.k8s.tasks import K8sTaskProvider
from iris.cluster.providers.protocols import WorkerInfraProvider
from iris.cluster.controller.worker_provider import WorkerProvider
Expand Down Expand Up @@ -1215,7 +1216,7 @@ def create_autoscaler(
label_prefix: str,
base_worker_config: config_pb2.WorkerConfig | None = None,
threads: ThreadContainer | None = None,
db: "ControllerDB | None" = None, # noqa: F821, UP037 — circular import
db: ControllerDB | None = None,
):
"""Create autoscaler from WorkerInfraProvider and explicit config.

Expand Down
15 changes: 9 additions & 6 deletions lib/iris/src/iris/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def _tasks_by_ids_with_attempts(queries: ControllerDB, task_ids: set[JobName]) -
return {task.task_id: task for task in tasks_with_attempts(tasks, attempts)}


def _building_counts(queries: ControllerDB, workers: list[WorkerRow]) -> dict[WorkerId, int]:
def _building_counts(queries: ControllerDB, workers: list[WorkerSnapshot]) -> dict[WorkerId, int]:
"""Count tasks in BUILDING or ASSIGNED state per worker, excluding reservation-holder jobs."""
if not workers:
return {}
Expand Down Expand Up @@ -672,9 +672,9 @@ def _worker_matches_reservation_entry(


def _inject_reservation_taints(
workers: list[WorkerRow],
workers: list[WorkerSnapshot],
claims: dict[WorkerId, ReservationClaim],
) -> list[WorkerRow]:
) -> list[WorkerSnapshot]:
"""Create modified worker copies with reservation taints and prioritization.

Claimed workers receive a ``reservation-job`` attribute set to the claiming
Expand All @@ -687,8 +687,8 @@ def _inject_reservation_taints(
if not claims:
return workers

claimed: list[WorkerRow] = []
unclaimed: list[WorkerRow] = []
claimed: list[WorkerSnapshot] = []
unclaimed: list[WorkerSnapshot] = []
for worker in workers:
claim = claims.get(worker.worker_id)
if claim is not None:
Expand Down Expand Up @@ -1536,14 +1536,17 @@ def _capture_one_profile(
duration: int,
) -> None:
"""Capture a single task profile via RPC and store it in the DB."""
# Profile loop is only spawned on the non-K8s provider path (see start()).
assert not isinstance(self._provider, K8sTaskProvider)
provider = self._provider
try:
request = job_pb2.ProfileTaskRequest(
target=task_id.to_wire(),
duration_seconds=duration,
profile_type=profile_type,
)
timeout_ms = duration * 1000 + 30000
resp = self._provider.profile_task(worker.address, request, timeout_ms=timeout_ms)
resp = provider.profile_task(worker.address, request, timeout_ms=timeout_ms)
if resp.error:
logger.debug("Profile (%s) failed for %s: %s", profile_kind, task_id, resp.error)
return
Expand Down
7 changes: 5 additions & 2 deletions lib/iris/src/iris/cluster/controller/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
from dataclasses import dataclass, field, replace as dc_replace
from pathlib import Path
from threading import Lock, RLock
from typing import Any
from typing import TYPE_CHECKING, Any

from iris.cluster.constraints import AttributeValue

if TYPE_CHECKING:
from iris.cluster.controller.endpoint_registry import EndpointRegistry
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid TYPE_CHECKING import guard in controller DB

This introduces a TYPE_CHECKING-guarded import, which violates the explicit Iris convention in lib/iris/AGENTS.md (“Avoid TYPE_CHECKING. Use real imports. If you hit a cycle, prefer refactoring or use a Protocol at the boundary.”). The same rule is also present in /AGENTS.md; please replace this guard with a structural cycle break rather than adding a forbidden import pattern.

Useful? React with 👍 / 👎.

from iris.cluster.controller.schema import decode_timestamp_ms, decode_worker_id
from iris.cluster.types import TERMINAL_TASK_STATES, JobName, WorkerId
from iris.rpc import job_pb2
Expand Down Expand Up @@ -331,7 +334,7 @@ def __init__(self, db_dir: Path):
logger.info("EndpointRegistry initialized in %.2fs", time.monotonic() - t0)

@property
def endpoints(self) -> EndpointRegistry: # noqa: F821
def endpoints(self) -> EndpointRegistry:
"""Process-local cache for the ``endpoints`` table; authoritative for reads."""
return self._endpoint_registry

Expand Down
4 changes: 2 additions & 2 deletions lib/levanter/src/levanter/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import abc
import logging
from typing import Callable, Generic, Optional, Sequence, TypeAlias, TypeVar
from typing import Any, Callable, Generic, Optional, Sequence, TypeAlias, TypeVar

import jax.random
import numpy as np
Expand Down Expand Up @@ -324,7 +324,7 @@ class BatchMappedAsyncDataset(AsyncDataset[U]):

def __init__(
self,
dataset: AsyncDataset[T],
dataset: AsyncDataset[Any],
fn: MapFunction[Sequence[U]],
*extra_args,
**extra_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion lib/levanter/src/levanter/data/sharded_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ class _TransformedDataset:


class _MappedShardedDataSource(ShardedDataSource[T], _TransformedDataset):
def __init__(self, source: ShardedDataSource[T_co], fn: Callable[[T_co], T]):
def __init__(self, source: ShardedDataSource[Any], fn: Callable[[Any], T]):
self.source = source
self.fn = fn
self._transform = _MapTransform(fn)
Expand Down
5 changes: 3 additions & 2 deletions lib/levanter/src/levanter/inference/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,15 @@ def _batch_processing_loop(self):
while not self.shutdown_event.is_set():
try:
batch = self.batch_queue.get(timeout=1)
except queue.Empty:
continue
try:
with (
self.model_lock,
hax.partitioning.set_mesh(self.config.trainer.device_mesh),
hax.axis_mapping(self.config.trainer.compute_axis_mapping),
):
self._execute_batch(batch)
except queue.Empty:
continue
except Exception as e:
logger.error(f"Error executing batch: {e}", exc_info=True)
# Set exceptions on all futures in the batch
Expand Down
8 changes: 4 additions & 4 deletions lib/levanter/src/levanter/optim/model_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class EmaModelAveraging(ModelAveraging[M]):
model: M
beta: float = eqx.field(static=True)

def update(self: S, new_model: M, step: int) -> S:
def update(self: S, model: M, step: int) -> S:
del step
# 1 - beta because increment_update expects the weight of the new model
return dataclasses.replace(self, model=optax.incremental_update(new_model, self.model, 1 - self.beta)) # type: ignore
return dataclasses.replace(self, model=optax.incremental_update(model, self.model, 1 - self.beta)) # type: ignore

@property
def model_params(self) -> M:
Expand Down Expand Up @@ -70,11 +70,11 @@ def _raw_weight(self, step: int) -> float:
frac = jnp.clip(t / self.decay_steps, 0.0, 1.0)
return float(1.0 - jnp.sqrt(frac))

def update(self, new_model: M, step: int) -> "EmaDecaySqrtModelAveraging[M]":
def update(self, model: M, step: int) -> "EmaDecaySqrtModelAveraging[M]":
w = self._raw_weight(step)
new_tot_w = self.tot_weight + w
alpha = 0.0 if new_tot_w == 0.0 else w / new_tot_w
updated = optax.incremental_update(new_model, self.model, alpha)
updated = optax.incremental_update(model, self.model, alpha)
return dataclasses.replace(self, model=updated, tot_weight=new_tot_w) # type: ignore[arg-type]

@property
Expand Down
4 changes: 2 additions & 2 deletions lib/levanter/src/levanter/store/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar, Union

import deepdiff
import jax
Expand Down Expand Up @@ -217,7 +217,7 @@ def empty():
return CacheMetadata()


class SerialCacheWriter:
class SerialCacheWriter(Generic[T]):
"""
Writes TreeCache-compatible caches to disk without Ray. Mostly for scripts and debugging.
"""
Expand Down
3 changes: 2 additions & 1 deletion lib/levanter/src/levanter/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def __init__(self, *args, bytes_strategy="base64", **kwargs):
super().__init__(*args, **kwargs)
self.bytes_strategy = bytes_strategy

def default(self, obj):
def default(self, o):
obj = o
# Known clean conversions
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
# ISO 8601; preserves tzinfo if present
Expand Down
10 changes: 9 additions & 1 deletion lib/marin/src/marin/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict

import jinja2
import yaml
Expand Down Expand Up @@ -314,7 +315,14 @@ def list_available_configs() -> list[str]:
},
}

GENERATION_CONFIGS = {

class _GenerationConfig(TypedDict):
runtime_version: str
base_worker: str
slices: list[int]


GENERATION_CONFIGS: dict[str, _GenerationConfig] = {
"v4": {
"runtime_version": "tpu-ubuntu2204-base",
"base_worker": "8",
Expand Down
Loading
Loading