[Serve] Optimize stop_replicas() to avoid pop-all/re-add cycle#60832

Open
abrarsheikh wants to merge 2 commits into master from optimize-stop-replicas

Conversation


@abrarsheikh abrarsheikh commented Feb 7, 2026

stop_replicas() pops every replica across all 7 states, checks set membership, and re-adds the vast majority. Each re-add triggers update_actor_details(), which rebuilds a ReplicaDetails pydantic object. When stopping 2 out of 4096 replicas, 4094 replicas are needlessly popped, rebuilt, and re-added.
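The churn is easy to reproduce with a toy version of the old control flow (the classes below are simplified stand-ins, not Serve's actual ones):

```python
from collections import defaultdict


class Replica:
    def __init__(self, rid: str):
        self.rid = rid
        self.rebuilds = 0  # stands in for ReplicaDetails pydantic rebuilds

    def update_state(self, state: str) -> None:
        self.rebuilds += 1  # in production this rebuilds actor details


# Fill a state->replicas container with 4096 RUNNING replicas.
states = defaultdict(list)
for i in range(4096):
    states["RUNNING"].append(Replica(f"r-{i}"))

to_stop = {"r-0", "r-1"}

# Old shape: pop every replica, then re-add the non-matching ones,
# firing update_state on each re-add.
popped = [r for bucket in states.values() for r in bucket]
states.clear()
stopped = []
for r in popped:
    if r.rid in to_stop:
        stopped.append(r)
    else:
        r.update_state("RUNNING")  # needless rebuild
        states["RUNNING"].append(r)
```

Stopping 2 of 4096 replicas this way rebuilds details for the other 4094.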

Fix

Add a remove(replica_ids) method to ReplicaStateContainer that performs a single O(N) pass with O(1) set lookups. Non-matching replicas stay in place: no re-add, no update_state call. The pass early-exits once all targets are found, and only rebuilds the per-state list for states where a match was found.
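The core of the change can be sketched independently of Serve's classes (plain strings stand in for replicas and states here; names are illustrative):

```python
from typing import Dict, List, Set


def remove_by_ids(buckets: Dict[str, List[str]], ids: Set[str]) -> List[str]:
    """Remove the given IDs from per-state buckets in one pass.

    Buckets with no match keep their original list object, so
    untouched replicas are never copied, rebuilt, or re-added.
    """
    removed: List[str] = []
    remaining_to_find = len(ids)
    for state, bucket in buckets.items():
        if remaining_to_find == 0:
            break  # early exit: all targets already found
        kept = []
        found_any = False
        for item in bucket:
            if remaining_to_find > 0 and item in ids:
                removed.append(item)
                remaining_to_find -= 1
                found_any = True
            else:
                kept.append(item)
        if found_any:
            buckets[state] = kept  # rebuild only buckets with a match
    return removed
```

A useful property to note: a bucket with no matches keeps its identity, which is exactly why no update_state calls fire for untouched replicas.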

Benchmark results

Benchmark script - AI
"""Micro-benchmark: stop_replicas() pop-all vs selective-remove.

Measures latency and peak memory when stopping a small fraction of replicas
from a ReplicaStateContainer, comparing the old approach (pop all + re-add)
against the new approach (selective remove by ID).

Usage:
    python bench_stop_replicas.py
"""

import gc
import math
import random
import statistics
import time
import tracemalloc
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Set


# ---------------------------------------------------------------------------
# Lightweight stubs – just enough to exercise ReplicaStateContainer logic
# without importing all of Ray Serve.
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class DeploymentID:
    name: str
    app_name: str = "default"

    def __hash__(self):
        return hash((self.name, self.app_name))


@dataclass(frozen=True)
class ReplicaID:
    unique_id: str
    deployment_id: DeploymentID

    def __hash__(self):
        return hash((self.unique_id, self.deployment_id))


class ReplicaState(str, Enum):
    STARTING = "STARTING"
    UPDATING = "UPDATING"
    RECOVERING = "RECOVERING"
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    PENDING_MIGRATION = "PENDING_MIGRATION"


ALL_REPLICA_STATES = list(ReplicaState)


class _StubActorDetails:
    """Minimal stand-in for ReplicaDetails (pydantic model in production).

    Deliberately uses __slots__ to keep memory footprint realistic but light.
    """

    __slots__ = ("state", "replica_id", "node_id")

    def __init__(self, state: ReplicaState, replica_id: str):
        self.state = state
        self.replica_id = replica_id
        self.node_id = "node-0"

    def dict(self):
        return {
            "state": self.state,
            "replica_id": self.replica_id,
            "node_id": self.node_id,
        }


class FakeReplica:
    """Minimal stand-in for DeploymentReplica.

    Implements only the surface area touched by the container and stop_replicas.
    """

    def __init__(self, replica_id: ReplicaID, state: ReplicaState):
        self._replica_id = replica_id
        self._actor_details = _StubActorDetails(state, replica_id.unique_id)
        self._update_state_calls = 0

    @property
    def replica_id(self) -> ReplicaID:
        return self._replica_id

    @property
    def actor_details(self):
        return self._actor_details

    def update_state(self, state: ReplicaState) -> None:
        """Mirrors update_actor_details; rebuilds the details object."""
        self._actor_details = _StubActorDetails(state, self._replica_id.unique_id)
        self._update_state_calls += 1


# ---------------------------------------------------------------------------
# Container implementations
# ---------------------------------------------------------------------------


class _ContainerBase:
    """Shared helpers."""

    def __init__(self):
        self._replicas: Dict[ReplicaState, List[FakeReplica]] = defaultdict(list)

    def add(self, state: ReplicaState, replica: FakeReplica):
        replica.update_state(state)
        self._replicas[state].append(replica)

    def pop(
        self,
        exclude_version=None,  # unused in this stub; kept for API parity
        states=None,
        max_replicas=math.inf,
    ) -> List[FakeReplica]:
        if states is None:
            states = ALL_REPLICA_STATES
        replicas = []
        for state in states:
            popped = []
            remaining = []
            for replica in self._replicas[state]:
                if len(replicas) + len(popped) == max_replicas:
                    remaining.append(replica)
                else:
                    popped.append(replica)
            self._replicas[state] = remaining
            replicas.extend(popped)
        return replicas


class OldContainer(_ContainerBase):
    """Original stop_replicas: pop everything, re-add non-matching."""

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        stopped = []
        for replica in self.pop():
            if replica.replica_id in replicas_to_stop:
                # In production this calls _stop_replica(); we just record it.
                stopped.append(replica)
            else:
                self.add(replica.actor_details.state, replica)
        return stopped


class NewContainer(_ContainerBase):
    """New stop_replicas: single-pass remove_many by ID set."""

    def remove(self, replica_ids) -> List[FakeReplica]:
        replica_ids = set(replica_ids)
        removed = []
        remaining_to_find = len(replica_ids)
        for state in ALL_REPLICA_STATES:
            if remaining_to_find == 0:
                break
            found_any = False
            remaining = []
            for replica in self._replicas[state]:
                if remaining_to_find > 0 and replica.replica_id in replica_ids:
                    removed.append(replica)
                    remaining_to_find -= 1
                    found_any = True
                else:
                    remaining.append(replica)
            if found_any:
                self._replicas[state] = remaining
        return removed

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        return self.remove(replicas_to_stop)


# ---------------------------------------------------------------------------
# Benchmark harness
# ---------------------------------------------------------------------------

DEPLOYMENT_ID = DeploymentID("bench-deploy", "bench-app")
WARMUP_ROUNDS = 3
MEASURE_ROUNDS = 20


def _make_replicas(n: int) -> List[FakeReplica]:
    return [
        FakeReplica(
            ReplicaID(f"r-{i}", DEPLOYMENT_ID),
            ReplicaState.RUNNING,
        )
        for i in range(n)
    ]


def _fill_container(container, replicas: List[FakeReplica]):
    for r in replicas:
        container.add(ReplicaState.RUNNING, r)


def _pick_targets(replicas: List[FakeReplica], k: int) -> Set[ReplicaID]:
    chosen = random.sample(replicas, k)
    return {r.replica_id for r in chosen}


def bench_latency(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with median, p99, mean latency in microseconds."""
    replicas_master = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_master, num_to_stop)

    timings = []

    for rnd in range(WARMUP_ROUNDS + MEASURE_ROUNDS):
        c = container_cls()
        _fill_container(c, replicas_master)

        gc.disable()
        t0 = time.perf_counter_ns()
        c.stop_replicas(targets)
        t1 = time.perf_counter_ns()
        gc.enable()

        if rnd >= WARMUP_ROUNDS:
            timings.append((t1 - t0) / 1_000)  # ns -> µs

    timings.sort()
    return {
        "median_us": statistics.median(timings),
        "p99_us": timings[int(len(timings) * 0.99)],
        "mean_us": statistics.mean(timings),
    }


def bench_memory(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with peak memory allocation delta in KiB.

    Averages over multiple runs to reduce noise.
    """
    NUM_MEM_RUNS = 5
    deltas = []
    for _ in range(NUM_MEM_RUNS):
        replicas_master = _make_replicas(n_replicas)
        targets = _pick_targets(replicas_master, num_to_stop)

        c = container_cls()
        _fill_container(c, replicas_master)

        gc.collect()
        tracemalloc.start()

        snap_before = tracemalloc.take_snapshot()
        c.stop_replicas(targets)
        snap_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        stats = snap_after.compare_to(snap_before, "lineno")
        delta_bytes = sum(s.size_diff for s in stats if s.size_diff > 0)
        deltas.append(delta_bytes / 1024)

    return {"peak_delta_kib": statistics.median(deltas)}


def bench_update_state_calls(n_replicas: int, num_to_stop: int) -> dict:
    """Returns update_state call counts for old vs new."""
    replicas_old = _make_replicas(n_replicas)
    replicas_new = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_old, num_to_stop)

    c_old = OldContainer()
    _fill_container(c_old, replicas_old)
    for r in replicas_old:
        r._update_state_calls = 0

    c_new = NewContainer()
    _fill_container(c_new, replicas_new)
    for r in replicas_new:
        r._update_state_calls = 0

    c_old.stop_replicas(targets)
    c_new.stop_replicas(targets)

    old_calls = sum(r._update_state_calls for r in replicas_old)
    new_calls = sum(r._update_state_calls for r in replicas_new)
    return {"old": old_calls, "new": new_calls}


def run_scenario(label: str, replica_counts: List[int], stop_fn):
    """Run a full latency + memory + update_state scenario.

    Args:
        label: human-readable description (e.g. "stopping 2 replicas").
        replica_counts: list of total replica counts to sweep.
        stop_fn: callable(n) -> number of replicas to stop for that n.
    """
    hdr = (
        f"{'Replicas':>8} {'k':>5}"
        f"  │ {'Old µs':>9} {'New µs':>9} {'Speedup':>7}"
        f"  │ {'Old KiB':>8} {'New KiB':>8} {'Saved':>8}"
        f"  │ {'Old upd':>8} {'New upd':>8}"
    )
    sep = "─" * len(hdr)

    print()
    print(f"  {label}")
    print(sep)
    print(hdr)
    print(sep)

    for n in replica_counts:
        k = stop_fn(n)
        old_lat = bench_latency(OldContainer, n, k)
        new_lat = bench_latency(NewContainer, n, k)
        speedup = old_lat["median_us"] / new_lat["median_us"] if new_lat["median_us"] > 0 else float("inf")
        old_mem = bench_memory(OldContainer, n, k)
        new_mem = bench_memory(NewContainer, n, k)
        saved = old_mem["peak_delta_kib"] - new_mem["peak_delta_kib"]
        us = bench_update_state_calls(n, k)

        print(
            f"{n:>8} {k:>5}"
            f"  │ {old_lat['median_us']:>8.1f}µ {new_lat['median_us']:>8.1f}µ {speedup:>6.1f}x"
            f"  │ {old_mem['peak_delta_kib']:>7.1f}  {new_mem['peak_delta_kib']:>7.1f} {saved:>7.1f}"
            f"  │ {us['old']:>8,} {us['new']:>8,}"
        )

    print(sep)


def main():
    replica_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    print()
    print("=" * 72)
    print("  stop_replicas() micro-benchmark")
    print("  Old = pop all + re-add non-matching  (original)")
    print("  New = single-pass remove by ID set   (optimized)")
    print("=" * 72)

    # Scenario 1: stop 2 replicas (typical downscale)
    run_scenario(
        "SCENARIO 1: Downscale — stop 2 out of N  (typical)",
        replica_counts,
        stop_fn=lambda n: 2,
    )

    # Scenario 2: stop all replicas (full teardown)
    run_scenario(
        "SCENARIO 2: Teardown — stop ALL N replicas  (worst case for old per-ID approach)",
        replica_counts,
        stop_fn=lambda n: n,
    )

    # Scenario 3: stop 10% of replicas
    run_scenario(
        "SCENARIO 3: Moderate downscale — stop 10% of N",
        [64, 128, 256, 512, 1024, 2048, 4096],
        stop_fn=lambda n: max(1, n // 10),
    )

    print()


if __name__ == "__main__":
    main()

Scenario 1 — Downscale: stop 2 out of N (typical)

| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old upd_state | New upd_state |
|---------:|--:|-------:|-------:|--------:|--------:|--------:|----------:|--------------:|--------------:|
| 64 | 2 | 53.6 | 8.4 | 6.4x | 5.1 | 1.2 | 3.8 | 62 | 0 |
| 256 | 2 | 211.3 | 62.4 | 3.4x | 17.2 | 2.8 | 14.3 | 254 | 0 |
| 1024 | 2 | 852.2 | 295.2 | 2.9x | 65.7 | 9.3 | 56.3 | 1,022 | 0 |
| 4096 | 2 | 3402.1 | 586.7 | 5.8x | 257.3 | 32.9 | 224.3 | 4,094 | 0 |

Scenario 2 — Teardown: stop ALL N (no regression)

| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB |
|---------:|---:|-------:|-------:|--------:|--------:|--------:|----------:|
| 256 | 256 | 98.8 | 84.6 | 1.2x | 1.2 | 0.7 | 0.5 |
| 1024 | 1024 | 402.0 | 340.6 | 1.2x | 1.2 | 0.7 | 0.5 |
| 4096 | 4096 | 1635.7 | 1378.4 | 1.2x | 1.2 | 0.7 | 0.5 |

Memory is flat and nearly identical for both — when stopping everything, neither approach re-adds anything.

Scenario 3 — Moderate downscale: stop 10% of N

| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old upd_state | New upd_state |
|---------:|--:|-------:|-------:|--------:|--------:|--------:|----------:|--------------:|--------------:|
| 256 | 25 | 205.3 | 75.9 | 2.7x | 15.6 | 2.5 | 13.1 | 231 | 0 |
| 1024 | 102 | 828.2 | 319.5 | 2.6x | 59.2 | 8.3 | 50.9 | 922 | 0 |
| 4096 | 409 | 3301.4 | 1299.3 | 2.5x | 235.0 | 32.9 | 202.1 | 3,687 | 0 |

The old code's memory grows linearly with N because pop() allocates a temporary list of all replicas, then add() rebuilds ReplicaDetails for each re-inserted one. The new code only allocates a remaining list for states where a match is found, and never touches non-matching replicas.
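The linear allocation of the pop-all shape can be seen directly with tracemalloc. This is a rough illustration, not the benchmark script's method, and the absolute byte counts vary by interpreter:

```python
import tracemalloc


def alloc_bytes(fn) -> int:
    """Net Python-level bytes allocated while running fn, per tracemalloc."""
    tracemalloc.start()
    before = tracemalloc.get_traced_memory()[0]
    result = fn()
    after = tracemalloc.get_traced_memory()[0]
    tracemalloc.stop()
    del result  # keep result alive until 'after' so it is counted
    return after - before


replicas = [f"r-{i}" for i in range(100_000)]
targets = {"r-0", "r-1"}

# Old shape: pop() copies every replica into a temporary list, O(N) bytes.
old_alloc = alloc_bytes(lambda: [r for r in replicas])

# New shape: only the few matching replicas are collected, O(k) bytes.
new_alloc = alloc_bytes(lambda: [r for r in replicas if r in targets])
```

On a typical CPython build, the full copy costs on the order of a pointer per replica while the selective pass allocates a tiny list, mirroring the KiB columns above.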

Related to #60680.

Signed-off-by: abrar <abrar@anyscale.com>
@abrarsheikh abrarsheikh requested a review from a team as a code owner February 7, 2026 07:53

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request optimizes stop_replicas to avoid a costly cycle of popping and re-adding all replicas when only a few need to be stopped. This is achieved by introducing a remove(replica_id) method in ReplicaStateContainer and using it in a loop within stop_replicas. The benchmarks clearly show significant performance improvements for stopping a small number of replicas.

My main concern, detailed in a specific comment, is a potential performance regression when stopping a large number of replicas, due to the O(k*N) complexity of the new approach. I've suggested an alternative that would be efficient for all cases.
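The reviewer's complexity concern can be made concrete with two hypothetical helpers (illustrative only, not the PR's code): one that scans the container once per ID, and one that makes a single pass with set lookups, each counting how many elements it visits.

```python
from typing import List, Set


def remove_per_id(container: List[str], ids: Set[str]) -> int:
    """Remove each target with its own full scan, the
    remove(replica_id)-in-a-loop shape: O(k*N). Returns element visits."""
    visits = 0
    for rid in ids:
        kept = []
        for item in container:
            visits += 1
            if item != rid:
                kept.append(item)
        container[:] = kept
    return visits


def remove_with_set(container: List[str], ids: Set[str]) -> int:
    """Remove all targets in one pass with O(1) set lookups: O(N).
    Returns element visits."""
    visits = 0
    kept = []
    for item in container:
        visits += 1
        if item not in ids:
            kept.append(item)
    container[:] = kept
    return visits
```

For N = 1000 and k = 100, the per-ID version visits roughly 95,000 elements while the set-based version visits exactly 1,000, which is why a batch remove(replica_ids) stays efficient even when stopping many replicas.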

Signed-off-by: abrar <abrar@anyscale.com>