Skip to content

Commit 1a4f124

Browse files
Persist GangContext in ReplicaContext
Signed-off-by: jeffreywang <jeffreywang@anyscale.com>
1 parent 95d7e5c commit 1a4f124

File tree

5 files changed

+165
-33
lines changed

5 files changed

+165
-33
lines changed

python/ray/serve/_private/common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,24 @@ class CreatePlacementGroupRequest:
859859
fallback_strategy: Optional[List[Dict[str, Any]]] = None
860860

861861

862+
@PublicAPI(stability="alpha")
863+
@dataclass
864+
class GangContext:
865+
"""Context information for a replica that is part of a gang.
866+
867+
Attributes:
868+
gang_id: Unique identifier for this gang.
869+
rank: This replica's rank within the gang (0-indexed).
870+
world_size: Total number of replicas in this gang.
871+
member_replica_ids: List of replica IDs in this gang, ordered by rank.
872+
"""
873+
874+
gang_id: str
875+
rank: int
876+
world_size: int
877+
member_replica_ids: List[str]
878+
879+
862880
@dataclass
863881
class GangPlacementGroupRequest:
864882
"""Request to prepare gang placement groups for a deployment.

python/ray/serve/_private/deployment_state.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
DeploymentStatusTrigger,
2929
DeploymentTargetInfo,
3030
Duration,
31+
GangContext,
3132
GangPlacementGroupRequest,
3233
GangPreparationResult,
3334
ReplicaID,
@@ -400,6 +401,10 @@ def deployment_name(self) -> str:
400401
def rank(self) -> Optional[ReplicaRank]:
401402
return self._rank
402403

404+
@property
405+
def gang_context(self) -> Optional[GangContext]:
406+
return self._gang_context
407+
403408
@property
404409
def app_name(self) -> str:
405410
return self._deployment_id.app_name
@@ -578,6 +583,7 @@ def start(
578583
assign_rank_callback: Callable[[ReplicaID], ReplicaRank],
579584
gang_placement_group: Optional[PlacementGroup] = None,
580585
gang_bundle_index: Optional[int] = None,
586+
gang_context: Optional[GangContext] = None,
581587
) -> ReplicaSchedulingRequest:
582588
"""Start the current DeploymentReplica instance.
583589
@@ -589,12 +595,14 @@ def start(
589595
assign_rank_callback: Callback to assign rank to the replica.
590596
gang_placement_group: Pre-created gang PG to schedule this replica on.
591597
gang_bundle_index: Bundle index within the gang PG for this replica.
598+
gang_context: Gang context for this replica (if part of a gang).
592599
"""
593600
self._assign_rank_callback = assign_rank_callback
594601
self._actor_resources = deployment_info.replica_config.resource_dict
595602
self._ingress = deployment_info.ingress
596603
self._gang_placement_group = gang_placement_group
597604
self._gang_bundle_index = gang_bundle_index
605+
self._gang_context = gang_context
598606
# it is currently not possible to create a placement group
599607
# with no resources (https://github.com/ray-project/ray/issues/20401)
600608
self._deployment_is_cross_language = (
@@ -894,7 +902,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]:
894902
self._replica_id.unique_id, self._node_id
895903
)
896904
self._ready_obj_ref = replica_ready_check_func.remote(
897-
deployment_config, self._rank
905+
deployment_config, self._rank, self._gang_context
898906
)
899907

900908
return ReplicaStartupStatus.PENDING_INITIALIZATION, None
@@ -1376,6 +1384,7 @@ def start(
13761384
assign_rank_callback: Callable[[ReplicaID], ReplicaRank],
13771385
gang_placement_group: Optional[PlacementGroup] = None,
13781386
gang_bundle_index: Optional[int] = None,
1387+
gang_context: Optional[GangContext] = None,
13791388
) -> ReplicaSchedulingRequest:
13801389
"""
13811390
Start a new actor for current DeploymentReplica instance.
@@ -1385,12 +1394,14 @@ def start(
13851394
assign_rank_callback: Callback to assign rank to the replica.
13861395
gang_placement_group: Pre-created gang PG to schedule this replica on.
13871396
gang_bundle_index: Bundle index within the gang PG for this replica.
1397+
gang_context: Gang context for this replica (if part of a gang).
13881398
"""
13891399
replica_scheduling_request = self._actor.start(
13901400
deployment_info,
13911401
assign_rank_callback=assign_rank_callback,
13921402
gang_placement_group=gang_placement_group,
13931403
gang_bundle_index=gang_bundle_index,
1404+
gang_context=gang_context,
13941405
)
13951406
self._start_time = time.time()
13961407
self._logged_shutdown_message = False
@@ -1431,6 +1442,11 @@ def rank(self) -> Optional[ReplicaRank]:
14311442
"""Get the rank assigned to the replica."""
14321443
return self._actor.rank
14331444

1445+
@property
1446+
def gang_context(self) -> Optional[GangContext]:
1447+
"""Get the gang context for this replica (if part of a gang)."""
1448+
return self._actor.gang_context
1449+
14341450
def check_started(
14351451
self,
14361452
) -> Tuple[ReplicaStartupStatus, Optional[str], Optional[float]]:
@@ -3060,20 +3076,35 @@ def _add_replicas_with_gang_scheduling(
30603076
"This should not happen."
30613077
)
30623078

3063-
for bundle_index in range(gang_size):
3064-
replica_id = ReplicaID(get_random_string(), deployment_id=self._id)
3079+
# Pre-generate replica IDs for all members of this gang
3080+
gang_id = get_random_string()
3081+
member_replica_ids = [
3082+
ReplicaID(get_random_string(), deployment_id=self._id)
3083+
for _ in range(gang_size)
3084+
]
3085+
3086+
for bundle_index, replica_id in enumerate(member_replica_ids):
3087+
gang_context = GangContext(
3088+
gang_id=gang_id,
3089+
rank=bundle_index,
3090+
world_size=gang_size,
3091+
member_replica_ids=[
3092+
r.unique_id for r in member_replica_ids
3093+
],
3094+
)
30653095

30663096
new_deployment_replica = DeploymentReplica(
30673097
replica_id,
30683098
self._target_state.version,
30693099
)
30703100

3071-
# Start the replica with gang PG information
3101+
# Start the replica with gang PG and gang context
30723102
scheduling_request = new_deployment_replica.start(
30733103
self._target_state.info,
30743104
assign_rank_callback=self._rank_manager.assign_rank,
30753105
gang_placement_group=gang_pg,
30763106
gang_bundle_index=bundle_index,
3107+
gang_context=gang_context,
30773108
)
30783109

30793110
upscale.append(scheduling_request)

python/ray/serve/_private/replica.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,9 @@ def __init__(
934934
# Flipped to `True` once graceful shutdown is initiated. May be used by replica
935935
# subclass implementations.
936936
self._shutting_down = False
937+
# Gang context for this replica (if part of a gang).
938+
# Populated during initialize_and_get_metadata.
939+
self._gang_context: Optional["GangContext"] = None
937940

938941
# Will be populated with the wrapped ASGI app if the user callable is an
939942
# `ASGIAppReplicaWrapper` (i.e., they are using the FastAPI integration).
@@ -1070,6 +1073,7 @@ def register_handle_callback(deployment_id: DeploymentID) -> None:
10701073
rank=rank,
10711074
world_size=world_size,
10721075
handle_registration_callback=register_handle_callback,
1076+
gang_context=self._gang_context,
10731077
)
10741078

10751079
def _configure_logger_and_profilers(
@@ -1382,8 +1386,13 @@ async def _on_initialized(self):
13821386
raise NotImplementedError
13831387

13841388
async def initialize(
1385-
self, deployment_config: Optional[DeploymentConfig], rank: Optional[ReplicaRank]
1389+
self,
1390+
deployment_config: Optional[DeploymentConfig],
1391+
rank: Optional[ReplicaRank],
1392+
gang_context: Optional["GangContext"] = None,
13861393
):
1394+
if gang_context is not None:
1395+
self._gang_context = gang_context
13871396
if rank is not None:
13881397
self._rank = rank
13891398
self._set_internal_replica_context(
@@ -2427,7 +2436,10 @@ def list_outbound_deployments(self) -> Optional[List[DeploymentID]]:
24272436
return self._replica_impl.list_outbound_deployments()
24282437

24292438
async def initialize_and_get_metadata(
2430-
self, deployment_config: DeploymentConfig = None, rank: ReplicaRank = None
2439+
self,
2440+
deployment_config: DeploymentConfig = None,
2441+
rank: ReplicaRank = None,
2442+
gang_context: "GangContext" = None,
24312443
) -> ReplicaMetadata:
24322444
"""Handles initializing the replica.
24332445
@@ -2440,7 +2452,7 @@ async def initialize_and_get_metadata(
24402452
"""
24412453
# Unused `_after` argument is for scheduling: passing an ObjectRef
24422454
# allows delaying this call until after the `_after` call has returned.
2443-
await self._replica_impl.initialize(deployment_config, rank)
2455+
await self._replica_impl.initialize(deployment_config, rank, gang_context)
24442456
return self._replica_impl.get_metadata()
24452457

24462458
async def check_health(self):

python/ray/serve/context.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import ray
1414
from ray.exceptions import RayActorError
1515
from ray.serve._private.client import ServeControllerClient
16-
from ray.serve._private.common import DeploymentID, ReplicaID
16+
from ray.serve._private.common import DeploymentID, GangContext, ReplicaID
1717
from ray.serve._private.config import DeploymentConfig
1818
from ray.serve._private.constants import (
1919
SERVE_CONTROLLER_NAME,
@@ -24,32 +24,14 @@
2424
from ray.serve.exceptions import RayServeException
2525
from ray.serve.grpc_util import RayServegRPCContext
2626
from ray.serve.schema import ReplicaRank
27-
from ray.util.annotations import DeveloperAPI, PublicAPI
27+
from ray.util.annotations import DeveloperAPI
2828

2929
logger = logging.getLogger(SERVE_LOGGER_NAME)
3030

3131
_INTERNAL_REPLICA_CONTEXT: "ReplicaContext" = None
3232
_global_client: ServeControllerClient = None
3333

3434

35-
@PublicAPI(stability="alpha")
36-
@dataclass
37-
class GangContext:
38-
"""Context information for a replica that is part of a gang.
39-
40-
Attributes:
41-
gang_id: Unique identifier for this gang.
42-
rank: This replica's rank within the gang (0-indexed).
43-
world_size: Total number of replicas in this gang.
44-
member_replica_ids: List of replica IDs in this gang, ordered by rank.
45-
"""
46-
47-
gang_id: str
48-
rank: int
49-
world_size: int
50-
member_replica_ids: List[str]
51-
52-
5335
@DeveloperAPI
5436
@dataclass
5537
class ReplicaContext:

python/ray/serve/tests/test_deployment_scheduler.py

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,7 @@ class TestGangScheduling:
736736
"""Tests for gang scheduling with placement groups."""
737737

738738
def test_sufficient_resources(self, ray_start_cluster):
739-
"""Verifies that gang scheduling succeeds when cluster has sufficient resources.
740-
"""
739+
"""Verifies that gang scheduling succeeds when cluster has sufficient resources."""
741740
cluster = ray_start_cluster
742741
cluster.add_node(num_cpus=4)
743742
cluster.add_node(num_cpus=4)
@@ -771,8 +770,7 @@ def GangDeployment():
771770
serve.shutdown()
772771

773772
def test_insufficient_resources(self, ray_start_cluster):
774-
"""Verifies that gang scheduling fails when cluster lacks resources.
775-
"""
773+
"""Verifies that gang scheduling fails when cluster lacks resources."""
776774
cluster = ray_start_cluster
777775
cluster.add_node(num_cpus=4)
778776
cluster.add_node(num_cpus=4)
@@ -807,8 +805,7 @@ def check_deploy_failed():
807805
serve.shutdown()
808806

809807
def test_pack_strategy(self, ray_start_cluster):
810-
"""Verifies that PACK strategy places gang replicas on the same node.
811-
"""
808+
"""Verifies that PACK strategy places gang replicas on the same node."""
812809
cluster = ray_start_cluster
813810
cluster.add_node(num_cpus=4)
814811
cluster.add_node(num_cpus=4)
@@ -895,5 +892,97 @@ def SpreadDeployment():
895892
serve.shutdown()
896893

897894

895+
def test_gang_context_populated(self, ray_start_cluster):
896+
"""Verifies GangContext is correctly populated in ReplicaContext."""
897+
cluster = ray_start_cluster
898+
cluster.add_node(num_cpus=4)
899+
cluster.wait_for_nodes()
900+
ray.init(address=cluster.address)
901+
serve.start()
902+
903+
@serve.deployment
904+
class GangContextDeployment:
905+
def __call__(self):
906+
ctx = ray.serve.context._get_internal_replica_context()
907+
gc = ctx.gang_context
908+
if gc is None:
909+
return None
910+
return {
911+
"gang_id": gc.gang_id,
912+
"rank": gc.rank,
913+
"world_size": gc.world_size,
914+
"member_replica_ids": gc.member_replica_ids,
915+
"replica_id": ctx.replica_id.unique_id,
916+
}
917+
918+
app = GangContextDeployment.options(
919+
num_replicas=4,
920+
ray_actor_options={"num_cpus": 1},
921+
gang_scheduling_config=GangSchedulingConfig(gang_size=2),
922+
).bind()
923+
924+
handle = serve.run(app, name="gang_context_app")
925+
wait_for_condition(
926+
check_apps_running,
927+
apps=["gang_context_app"],
928+
timeout=60,
929+
)
930+
931+
# Collect gang contexts from all replicas
932+
# Query enough times to hit all 4 replicas
933+
contexts_by_replica = {}
934+
for _ in range(100):
935+
result = handle.remote().result()
936+
assert result is not None
937+
replica_id = result["replica_id"]
938+
if replica_id not in contexts_by_replica:
939+
contexts_by_replica[replica_id] = result
940+
if len(contexts_by_replica) == 4:
941+
break
942+
assert len(contexts_by_replica) == 4
943+
944+
# Group replicas by gang_id
945+
gangs = {}
946+
for replica_id, ctx in contexts_by_replica.items():
947+
gang_id = ctx["gang_id"]
948+
gangs.setdefault(gang_id, []).append(ctx)
949+
950+
# Should have exactly 2 gangs
951+
assert len(gangs) == 2
952+
953+
for gang_id, members in gangs.items():
954+
# Each gang should have exactly 2 replicas
955+
assert len(members) == 2
956+
957+
# All members should have the same world_size
958+
assert all(member["world_size"] == 2 for member in members)
959+
960+
# All members should have the same member_replica_ids
961+
assert members[0]["member_replica_ids"] == members[1]["member_replica_ids"]
962+
963+
# member_replica_ids should contain exactly the 2 replica IDs in this gang
964+
expected_ids = sorted([m["replica_id"] for m in members])
965+
actual_ids = sorted(members[0]["member_replica_ids"])
966+
assert actual_ids == expected_ids
967+
968+
# Ranks within the gang should be {0, 1}
969+
ranks = sorted([m["rank"] for m in members])
970+
assert ranks == [0, 1]
971+
972+
# Across gangs: gang_ids should be different (already guaranteed by dict keys)
973+
gang_ids = list(gangs.keys())
974+
assert gang_ids[0] != gang_ids[1]
975+
976+
# Across gangs: member_replica_ids should be different
977+
gang_members_list = list(gangs.values())
978+
assert (
979+
sorted(gang_members_list[0][0]["member_replica_ids"])
980+
!= sorted(gang_members_list[1][0]["member_replica_ids"])
981+
)
982+
983+
serve.delete("gang_context_app")
984+
serve.shutdown()
985+
986+
898987
if __name__ == "__main__":
899988
sys.exit(pytest.main(["-v", "-s", __file__]))

0 commit comments

Comments
 (0)