Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions lib/iris/src/iris/client/worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def __init__(
self._timeout = timeout
self._discover_backoff = ExponentialBackoff(initial=0.05, maximum=1.0)
self._actor_client: ActorClient | None = None
self._stop_event: threading.Event | None = None

def make_target(self) -> Callable[..., None]:
"""Create a thread target that carries the current context.
Expand All @@ -211,27 +210,19 @@ def target(stop_event: threading.Event) -> None:
return target

def _run(self, stop_event: threading.Event) -> None:
self._stop_event = stop_event
while not stop_event.is_set():
if self.state.status == WorkerStatus.PENDING:
self._discover_endpoint()
self._discover_endpoint(stop_event)
continue

if self.state.status == WorkerStatus.FAILED:
break

if self._actor_client is None:
self._actor_client = ActorClient(
resolver=self._resolver,
name=self.state.worker_name,
call_timeout=self._timeout,
)

task = self._get_task()
if task:
self._execute_task(task)

def _discover_endpoint(self) -> None:
def _discover_endpoint(self, stop_event: threading.Event) -> None:
logger.debug(
"Discovering endpoint for worker %s (name=%s)",
self.state.worker_id,
Expand All @@ -252,10 +243,7 @@ def _discover_endpoint(self) -> None:
logger.info("Worker %s discovered at %s", self.state.worker_id, endpoint.url)
else:
logger.debug("Worker %s not found, waiting...", self.state.worker_id)
if self._stop_event:
self._stop_event.wait(self._discover_backoff.next_interval())
else:
time.sleep(self._discover_backoff.next_interval())
stop_event.wait(self._discover_backoff.next_interval())

def _get_task(self) -> PendingTask | None:
"""Try to get a task from the queue."""
Expand Down
5 changes: 3 additions & 2 deletions lib/levanter/src/levanter/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def value_at_step(schedule_or_t: Sequence[ScheduleStep[T]] | T, step: int) -> T:
if not isinstance(schedule_or_t, Sequence) or (schedule_or_t and not isinstance(schedule_or_t[0], ScheduleStep)):
return schedule_or_t # type: ignore

for i, step_ in enumerate(schedule_or_t):
# we use start now
# Iterate in reverse to find the last segment whose start is <= step.
# A forward loop would always stop at the first segment (typically start=0).
for step_ in reversed(schedule_or_t):
if step >= step_.start:
return step_.value

Expand Down
35 changes: 34 additions & 1 deletion lib/levanter/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from levanter.schedule import BatchSchedule, ScheduleStep
from levanter.schedule import BatchSchedule, ScheduleStep, value_at_step


@pytest.fixture
Expand Down Expand Up @@ -59,3 +59,36 @@ def test_batch_scheduler(scheduler, step, expected_bs, expected_offset, expected
assert bs == expected_bs, f"Unexpected batch size at step {step}"
assert offset == expected_offset, f"Unexpected data offset at step {step}"
# assert indices == expected_indices, f"Unexpected batch indices at step {step}"


def test_value_at_step_scalar():
assert value_at_step(42, 0) == 42
assert value_at_step(42, 1000) == 42


@pytest.mark.parametrize(
"step, expected",
[
(0, 32),
(500, 32),
(999, 32),
(1000, 64),
(50000, 64),
(99999, 64),
(100000, 128),
(250000, 128),
],
)
def test_value_at_step_schedule(step, expected):
schedule = [
ScheduleStep(start=0, value=32),
ScheduleStep(start=1000, value=64),
ScheduleStep(start=100000, value=128),
]
assert value_at_step(schedule, step) == expected


def test_value_at_step_before_first_segment_raises():
schedule = [ScheduleStep(start=100, value="a")]
with pytest.raises(ValueError):
value_at_step(schedule, 50)
71 changes: 0 additions & 71 deletions lib/marin/src/marin/core/runtime.py

This file was deleted.

Loading