Skip to content

[Serve] Eliminate per-replica-per-tick Pydantic rebuild in update_act…#60833

Open
abrarsheikh wants to merge 2 commits intomasterfrom
60680-abrar-rebuild_replica
Open

[Serve] Eliminate per-replica-per-tick Pydantic rebuild in update_act…#60833
abrarsheikh wants to merge 2 commits intomasterfrom
60680-abrar-rebuild_replica

Conversation

@abrarsheikh
Copy link
Contributor

@abrarsheikh abrarsheikh commented Feb 7, 2026

DeploymentReplica.update_actor_details() is called on every replica on every controller tick due to the pop-iterate-readd pattern in ReplicaStateContainer.add(). The old implementation ran a full .dict() serialization followed by ReplicaDetails(**kwargs) validation on every call — even when nothing changed (e.g., a RUNNING replica re-added as RUNNING). With Pydantic v1 compat this is especially expensive (~20 µs/replica).

This PR makes two changes:

  1. Early-exit guard: Skip the update entirely when all provided values already match the current model. This short-circuits the dominant hot path (same-state readd) at the cost of a few getattr comparisons.
  2. .copy(update=...) instead of .dict() + reconstruct: When an update is needed, use Pydantic's .copy(update=kwargs) which creates a shallow copy without full serialization or re-validation.

Benchmark results (real ReplicaDetails Pydantic model, 16 → 4096 replicas)

Benchmark Script - AI
"""Micro-benchmark: stop_replicas() pop-all vs selective-remove.

Measures latency and peak memory when stopping a small fraction of replicas
from a ReplicaStateContainer, comparing the old approach (pop all + re-add)
against the new approach (selective remove by ID).

Usage:
    python bench_stop_replicas.py
"""

import gc
import math
import random
import statistics
import time
import tracemalloc
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set


# ---------------------------------------------------------------------------
# Lightweight stubs – just enough to exercise ReplicaStateContainer logic
# without importing all of Ray Serve.
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class DeploymentID:
    name: str
    app_name: str = "default"

    def __hash__(self):
        return hash((self.name, self.app_name))


@dataclass(frozen=True)
class ReplicaID:
    unique_id: str
    deployment_id: DeploymentID

    def __hash__(self):
        return hash((self.unique_id, self.deployment_id))


class ReplicaState(str, Enum):
    STARTING = "STARTING"
    UPDATING = "UPDATING"
    RECOVERING = "RECOVERING"
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    PENDING_MIGRATION = "PENDING_MIGRATION"


ALL_REPLICA_STATES = list(ReplicaState)


class _StubActorDetails:
    """Minimal stand-in for ReplicaDetails (pydantic model in production).

    Deliberately uses __slots__ to keep memory footprint realistic but light.
    """

    __slots__ = ("state", "replica_id", "node_id")

    def __init__(self, state: ReplicaState, replica_id: str):
        self.state = state
        self.replica_id = replica_id
        self.node_id = "node-0"

    def dict(self):
        return {
            "state": self.state,
            "replica_id": self.replica_id,
            "node_id": self.node_id,
        }


class FakeReplica:
    """Minimal stand-in for DeploymentReplica.

    Implements only the surface area touched by the container and stop_replicas.
    """

    def __init__(self, replica_id: ReplicaID, state: ReplicaState):
        self._replica_id = replica_id
        self._actor_details = _StubActorDetails(state, replica_id.unique_id)
        self._update_state_calls = 0

    @property
    def replica_id(self) -> ReplicaID:
        return self._replica_id

    @property
    def actor_details(self):
        return self._actor_details

    def update_state(self, state: ReplicaState) -> None:
        """Mirrors update_actor_details; rebuilds the details object."""
        self._actor_details = _StubActorDetails(state, self._replica_id.unique_id)
        self._update_state_calls += 1


# ---------------------------------------------------------------------------
# Container implementations
# ---------------------------------------------------------------------------


class _ContainerBase:
    """Shared helpers."""

    def __init__(self):
        self._replicas: Dict[ReplicaState, List[FakeReplica]] = defaultdict(list)

    def add(self, state: ReplicaState, replica: FakeReplica):
        replica.update_state(state)
        self._replicas[state].append(replica)

    def pop(
        self,
        exclude_version=None,
        states=None,
        max_replicas=math.inf,
    ) -> List[FakeReplica]:
        if states is None:
            states = ALL_REPLICA_STATES
        replicas = []
        for state in states:
            popped = []
            remaining = []
            for replica in self._replicas[state]:
                if len(replicas) + len(popped) == max_replicas:
                    remaining.append(replica)
                else:
                    popped.append(replica)
            self._replicas[state] = remaining
            replicas.extend(popped)
        return replicas


class OldContainer(_ContainerBase):
    """Original stop_replicas: pop everything, re-add non-matching."""

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        stopped = []
        for replica in self.pop():
            if replica.replica_id in replicas_to_stop:
                # In production this calls _stop_replica(); we just record it.
                stopped.append(replica)
            else:
                self.add(replica.actor_details.state, replica)
        return stopped


class NewContainer(_ContainerBase):
    """New stop_replicas: single-pass remove_many by ID set."""

    def remove(self, replica_ids) -> List[FakeReplica]:
        replica_ids = set(replica_ids)
        removed = []
        remaining_to_find = len(replica_ids)
        for state in ALL_REPLICA_STATES:
            if remaining_to_find == 0:
                break
            found_any = False
            remaining = []
            for replica in self._replicas[state]:
                if remaining_to_find > 0 and replica.replica_id in replica_ids:
                    removed.append(replica)
                    remaining_to_find -= 1
                    found_any = True
                else:
                    remaining.append(replica)
            if found_any:
                self._replicas[state] = remaining
        return removed

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        return self.remove(replicas_to_stop)


# ---------------------------------------------------------------------------
# Benchmark harness
# ---------------------------------------------------------------------------

DEPLOYMENT_ID = DeploymentID("bench-deploy", "bench-app")
WARMUP_ROUNDS = 3
MEASURE_ROUNDS = 20


def _make_replicas(n: int) -> List[FakeReplica]:
    return [
        FakeReplica(
            ReplicaID(f"r-{i}", DEPLOYMENT_ID),
            ReplicaState.RUNNING,
        )
        for i in range(n)
    ]


def _fill_container(container, replicas: List[FakeReplica]):
    for r in replicas:
        container.add(ReplicaState.RUNNING, r)


def _pick_targets(replicas: List[FakeReplica], k: int) -> Set[ReplicaID]:
    chosen = random.sample(replicas, k)
    return {r.replica_id for r in chosen}


def bench_latency(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with median, p99, mean latency in microseconds."""
    replicas_master = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_master, num_to_stop)

    timings = []

    for rnd in range(WARMUP_ROUNDS + MEASURE_ROUNDS):
        c = container_cls()
        _fill_container(c, replicas_master)

        gc.disable()
        t0 = time.perf_counter_ns()
        c.stop_replicas(targets)
        t1 = time.perf_counter_ns()
        gc.enable()

        if rnd >= WARMUP_ROUNDS:
            timings.append((t1 - t0) / 1_000)  # ns -> µs

    timings.sort()
    return {
        "median_us": statistics.median(timings),
        "p99_us": timings[int(len(timings) * 0.99)],
        "mean_us": statistics.mean(timings),
    }


def bench_memory(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with peak memory allocation delta in KiB.

    Averages over multiple runs to reduce noise.
    """
    NUM_MEM_RUNS = 5
    deltas = []
    for _ in range(NUM_MEM_RUNS):
        replicas_master = _make_replicas(n_replicas)
        targets = _pick_targets(replicas_master, num_to_stop)

        c = container_cls()
        _fill_container(c, replicas_master)

        gc.collect()
        tracemalloc.start()

        snap_before = tracemalloc.take_snapshot()
        c.stop_replicas(targets)
        snap_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        stats = snap_after.compare_to(snap_before, "lineno")
        delta_bytes = sum(s.size_diff for s in stats if s.size_diff > 0)
        deltas.append(delta_bytes / 1024)

    return {"peak_delta_kib": statistics.median(deltas)}


def bench_update_state_calls(n_replicas: int, num_to_stop: int) -> dict:
    """Returns update_state call counts for old vs new."""
    replicas_old = _make_replicas(n_replicas)
    replicas_new = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_old, num_to_stop)

    c_old = OldContainer()
    _fill_container(c_old, replicas_old)
    for r in replicas_old:
        r._update_state_calls = 0

    c_new = NewContainer()
    _fill_container(c_new, replicas_new)
    for r in replicas_new:
        r._update_state_calls = 0

    c_old.stop_replicas(targets)
    c_new.stop_replicas(targets)

    old_calls = sum(r._update_state_calls for r in replicas_old)
    new_calls = sum(r._update_state_calls for r in replicas_new)
    return {"old": old_calls, "new": new_calls}


def run_scenario(label: str, replica_counts: List[int], stop_fn):
    """Run a full latency + memory + update_state scenario.

    Args:
        label: human-readable description (e.g. "stopping 2 replicas").
        replica_counts: list of total replica counts to sweep.
        stop_fn: callable(n) -> number of replicas to stop for that n.
    """
    hdr = (
        f"{'Replicas':>8} {'k':>5}"
        f"  │ {'Old µs':>9} {'New µs':>9} {'Speedup':>7}"
        f"  │ {'Old KiB':>8} {'New KiB':>8} {'Saved':>8}"
        f"  │ {'Old upd':>8} {'New upd':>8}"
    )
    sep = "─" * len(hdr)

    print()
    print(f"  {label}")
    print(sep)
    print(hdr)
    print(sep)

    for n in replica_counts:
        k = stop_fn(n)
        old_lat = bench_latency(OldContainer, n, k)
        new_lat = bench_latency(NewContainer, n, k)
        speedup = old_lat["median_us"] / new_lat["median_us"] if new_lat["median_us"] > 0 else float("inf")
        old_mem = bench_memory(OldContainer, n, k)
        new_mem = bench_memory(NewContainer, n, k)
        saved = old_mem["peak_delta_kib"] - new_mem["peak_delta_kib"]
        us = bench_update_state_calls(n, k)

        print(
            f"{n:>8} {k:>5}"
            f"  │ {old_lat['median_us']:>8.1f}µ {new_lat['median_us']:>8.1f}µ {speedup:>6.1f}x"
            f"  │ {old_mem['peak_delta_kib']:>7.1f}  {new_mem['peak_delta_kib']:>7.1f} {saved:>7.1f}"
            f"  │ {us['old']:>8,} {us['new']:>8,}"
        )

    print(sep)


def main():
    replica_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    print()
    print("=" * 72)
    print("  stop_replicas() micro-benchmark")
    print("  Old = pop all + re-add non-matching  (original)")
    print("  New = single-pass remove by ID set   (optimized)")
    print("=" * 72)

    # Scenario 1: stop 2 replicas (typical downscale)
    run_scenario(
        "SCENARIO 1: Downscale — stop 2 out of N  (typical)",
        replica_counts,
        stop_fn=lambda n: 2,
    )

    # Scenario 2: stop all replicas (full teardown)
    run_scenario(
        "SCENARIO 2: Teardown — stop ALL N replicas  (worst case for old per-ID approach)",
        replica_counts,
        stop_fn=lambda n: n,
    )

    # Scenario 3: stop 10% of replicas
    run_scenario(
        "SCENARIO 3: Moderate downscale — stop 10% of N",
        [64, 128, 256, 512, 1024, 2048, 4096],
        stop_fn=lambda n: max(1, n // 10),
    )

    print()


if __name__ == "__main__":
    main()
Scenario N Old µs New µs Speedup Old ns/r New ns/r
Same-state readd (RUNNING→RUNNING) 16 331 10 33x 20,692 619
128 2,671 78 35x 20,864 605
1,024 20,762 600 35x 20,275 586
4,096 85,301 2,600 33x 20,825 635
State transition (STARTING→RUNNING) 16 322 42 7.7x 20,099 2,619
128 2,623 332 7.9x 20,490 2,591
1,024 20,610 2,701 7.6x 20,126 2,637
4,096 83,353 10,675 7.8x 20,350 2,606
Multi-field no-op (6 fields, same values) 16 328 18 18x 20,480 1,135
128 2,587 135 19x 20,208 1,054
1,024 21,031 1,111 19x 20,538 1,085
4,096 84,994 4,420 19x 20,751 1,079
Multi-field changed (6 fields, new values) 16 329 45 7.2x 20,553 2,838
128 2,649 361 7.3x 20,693 2,821
1,024 20,788 2,926 7.1x 20,301 2,857
4,096 85,915 11,838 7.3x 20,975 2,890

At 4,096 replicas the steady-state tick cost drops from 85 ms → 2.6 ms (same-state readd, the dominant path). Even real state transitions drop from 83 ms → 10.7 ms.

Test plan

  • Existing unit tests pass (test_deployment_state.py)

Related to #60680

…or_details

Signed-off-by: abrar <abrar@anyscale.com>
@abrarsheikh abrarsheikh requested a review from a team as a code owner February 7, 2026 08:38
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a significant performance optimization to DeploymentReplica.update_actor_details(). By adding an early-exit guard for no-op updates and switching from .dict() reconstruction to Pydantic's more efficient .copy(update=...) method, it greatly reduces overhead in a frequently called function. The changes are well-reasoned and supported by comprehensive benchmarks. My review includes one suggestion to further improve the robustness of the early-exit check by using a sentinel object to handle invalid keyword arguments more consistently.

Signed-off-by: abrar <abrar@anyscale.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant