python/ray/serve/_private/common.py (+39, -0)
@@ -859,6 +859,45 @@ class CreatePlacementGroupRequest:
fallback_strategy: Optional[List[Dict[str, Any]]] = None


@PublicAPI(stability="alpha")
@dataclass
class GangContext:
"""Context information for a replica that is part of a gang.

Attributes:
gang_id: Unique identifier for this gang.
rank: This replica's rank within the gang (0-indexed).
world_size: Total number of replicas in this gang.
member_replica_ids: List of replica IDs in this gang, ordered by rank.
"""

gang_id: str
rank: int
world_size: int
member_replica_ids: List[str]
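
For reference, a minimal sketch of constructing and consuming a GangContext (all values hypothetical):

ctx = GangContext(
    gang_id="gang-1f2e3d",
    rank=0,
    world_size=4,
    member_replica_ids=["replica-0", "replica-1", "replica-2", "replica-3"],
)
# member_replica_ids is ordered by rank, so a replica can locate itself and its peers:
assert ctx.member_replica_ids[ctx.rank] == "replica-0"
peers = [r for i, r in enumerate(ctx.member_replica_ids) if i != ctx.rank]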


@dataclass
class GangPlacementGroupRequest:
"""Request to prepare gang placement groups for a deployment."""

deployment_id: DeploymentID
gang_size: int
gang_placement_strategy: str
num_replicas_to_add: int
replica_resource_dict: Dict[str, float]


@dataclass
class GangPreparationResult:
"""Result of gang placement group preparation."""

success: bool
error_message: Optional[str] = None
# Map of gang_index -> PlacementGroup reserved for that gang's replicas
gang_pgs: Dict[int, Any] = field(default_factory=dict)
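
And a sketch of the request/result pair from a caller's perspective (the deployment ID, resource amounts, and the `scheduler` handle are hypothetical):

request = GangPlacementGroupRequest(
    deployment_id=DeploymentID(name="trainer", app_name="app"),
    gang_size=4,
    gang_placement_strategy="STRICT_PACK",
    num_replicas_to_add=8,  # must be a multiple of gang_size
    replica_resource_dict={"CPU": 1, "GPU": 1},
)
results = scheduler.schedule_gang_placement_groups({request.deployment_id: request})
result = results[request.deployment_id]
if not result.success:
    raise RuntimeError(result.error_message)
for gang_index, pg in result.gang_pgs.items():
    ...  # assign replicas to bundles 0..gang_size-1 of each pg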


# This error is used to raise when a by-value DeploymentResponse is converted to an
# ObjectRef.
OBJ_REF_NOT_SUPPORTED_ERROR = RuntimeError(
python/ray/serve/_private/config.py (+48, -0)
@@ -32,6 +32,9 @@
AggregationFunction,
AutoscalingConfig,
DeploymentMode,
GangPlacementStrategy,
GangRuntimeFailurePolicy,
GangSchedulingConfig,
HTTPOptions,
ProxyLocation,
RequestRouterConfig,
@@ -41,6 +44,9 @@
DeploymentConfig as DeploymentConfigProto,
DeploymentLanguage,
EncodingType as EncodingTypeProto,
GangPlacementStrategy as GangPlacementStrategyProto,
GangRuntimeFailurePolicy as GangRuntimeFailurePolicyProto,
GangSchedulingConfig as GangSchedulingConfigProto,
LoggingConfig as LoggingConfigProto,
ReplicaConfig as ReplicaConfigProto,
RequestRouterConfig as RequestRouterConfigProto,
@@ -199,6 +205,10 @@ class DeploymentConfig(BaseModel):
default=DEFAULT_CONSTRUCTOR_RETRY_COUNT,
update_type=DeploymentOptionUpdateType.NeedsReconfigure,
)
gang_scheduling_config: Optional[GangSchedulingConfig] = Field(
default=None,
update_type=DeploymentOptionUpdateType.HeavyWeight,
)

# Contains the names of deployment options manually set by the user
user_configured_option_names: Set[str] = set()
@@ -246,6 +256,20 @@ def validate_max_queued_requests(cls, v):

return v

@validator("gang_scheduling_config", always=True)
def validate_gang_scheduling_config(cls, v, values):
if v is None:
return v

num_replicas = values.get("num_replicas")
if num_replicas % v.gang_size != 0:
raise ValueError(
f"num_replicas ({num_replicas}) must be a multiple of "
f"gang_size ({v.gang_size})."
)
Reviewer comment on lines +264 to +268 (Contributor, severity: medium):

num_replicas could be None here, which would cause a TypeError when the modulo operator is applied. While pydantic's default-value handling might prevent this in practice, checking that num_replicas is not None would make this validator more robust against unexpected None values.

Suggested change:

if num_replicas is not None and num_replicas % v.gang_size != 0:
    raise ValueError(
        f"num_replicas ({num_replicas}) must be a multiple of "
        f"gang_size ({v.gang_size})."
    )


return v
Reviewer comment, high severity: validator crashes when num_replicas is None.

The validate_gang_scheduling_config validator evaluates num_replicas % v.gang_size without checking whether num_replicas is None. The num_replicas field is Optional[NonNegativeInt] and can be None when autoscaling is used (e.g. num_replicas="auto"). Additionally, in Pydantic v1, if num_replicas fails its own validation it won't be present in values, so values.get("num_replicas") returns None. In either case, None % v.gang_size raises a TypeError.
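
For illustration, the multiple-of constraint behaves as follows (a sketch assuming the remaining DeploymentConfig and GangSchedulingConfig fields take their defaults):

cfg = GangSchedulingConfig(gang_size=4)
DeploymentConfig(num_replicas=8, gang_scheduling_config=cfg)  # ok: 8 % 4 == 0
DeploymentConfig(num_replicas=6, gang_scheduling_config=cfg)  # raises ValueError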


def needs_pickle(self):
return _needs_pickle(self.deployment_language, self.is_cross_language)

@@ -295,6 +319,19 @@ def to_proto(self):
data["user_configured_option_names"] = list(
data["user_configured_option_names"]
)
if data.get("gang_scheduling_config"):
gang_config = data["gang_scheduling_config"]
placement_strategy = GangPlacementStrategyProto.Value(
gang_config["gang_placement_strategy"]
)
failure_policy = GangRuntimeFailurePolicyProto.Value(
gang_config["runtime_failure_policy"]
)
data["gang_scheduling_config"] = GangSchedulingConfigProto(
gang_size=gang_config["gang_size"],
gang_placement_strategy=placement_strategy,
runtime_failure_policy=failure_policy,
)
return DeploymentConfigProto(**data)

def to_proto_bytes(self):
@@ -374,6 +411,17 @@ def from_proto(cls, proto: DeploymentConfigProto):
data["logging_config"]["encoding"] = EncodingTypeProto.Name(
data["logging_config"]["encoding"]
)
if "gang_scheduling_config" in data and data["gang_scheduling_config"]:
gang_config = data["gang_scheduling_config"]
gang_config["gang_placement_strategy"] = GangPlacementStrategy(
GangPlacementStrategyProto.Name(gang_config["gang_placement_strategy"])
)
gang_config["runtime_failure_policy"] = GangRuntimeFailurePolicy(
GangRuntimeFailurePolicyProto.Name(gang_config["runtime_failure_policy"])
)
data["gang_scheduling_config"] = GangSchedulingConfig(**gang_config)
else:
data.pop("gang_scheduling_config", None)

return cls(**data)
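
A hedged sanity check of the proto round-trip added above (the enum member names are assumptions, not confirmed by this diff):

original = DeploymentConfig(
    num_replicas=8,
    gang_scheduling_config=GangSchedulingConfig(
        gang_size=4,
        gang_placement_strategy=GangPlacementStrategy.STRICT_PACK,  # assumed member
        runtime_failure_policy=GangRuntimeFailurePolicy.RESTART_GANG,  # assumed member
    ),
)
restored = DeploymentConfig.from_proto(original.to_proto())
assert restored.gang_scheduling_config.gang_size == 4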

python/ray/serve/_private/deployment_scheduler.py (+133, -1)
@@ -1,6 +1,7 @@
import copy
import logging
import sys
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
@@ -15,6 +16,8 @@
from ray.serve._private.common import (
CreatePlacementGroupRequest,
DeploymentID,
GangPlacementGroupRequest,
GangPreparationResult,
ReplicaID,
)
from ray.serve._private.config import ReplicaConfig
@@ -159,6 +162,10 @@ class ReplicaSchedulingRequest:
placement_group_bundle_label_selector: Optional[List[Dict[str, str]]] = None
placement_group_fallback_strategy: Optional[List[Dict[str, Any]]] = None
max_replicas_per_node: Optional[int] = None
# Gang scheduling fields - if set, replica should be scheduled on
# the pre-created gang placement group at the specified bundle index.
gang_placement_group: Optional[Any] = None # PlacementGroup
gang_bundle_index: Optional[int] = None

@property
def required_resources(self) -> Resources:
@@ -562,7 +569,22 @@ def _schedule_replica(
placement_group = None

scheduling_strategy = default_scheduling_strategy
if scheduling_request.placement_group_bundles is not None:

# Gang scheduling path - use pre-created gang placement group
if scheduling_request.gang_placement_group is not None:
placement_group = scheduling_request.gang_placement_group
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=scheduling_request.gang_bundle_index,
placement_group_capture_child_tasks=True,
)
target_labels = None
target_node_id = None
logger.debug(
f"Scheduling {replica_id} on gang PG "
f"(bundle_index={scheduling_request.gang_bundle_index})."
)
elif scheduling_request.placement_group_bundles is not None:
placement_group_strategy = (
scheduling_request.placement_group_strategy
if scheduling_request.placement_group_strategy
@@ -649,6 +671,14 @@ def get_node_to_compact(
"""Returns a node ID to be compacted and a compaction deadlne."""
raise NotImplementedError

@abstractmethod
def schedule_gang_placement_groups(
self,
gang_requests: Dict[DeploymentID, GangPlacementGroupRequest],
) -> Dict[DeploymentID, GangPreparationResult]:
"""Reserve resources for gang scheduling."""
raise NotImplementedError
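
For context, the gang path in _schedule_replica above pins a replica actor to a specific bundle of its pre-created placement group; in plain Ray the same pattern looks like this (a sketch; `pg` is assumed to be a ready placement group):

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@ray.remote(num_cpus=1)
class Replica:
    pass

# Pin this actor to bundle 2 of the gang's placement group.
actor = Replica.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_bundle_index=2,
        placement_group_capture_child_tasks=True,
    )
).remote()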


class DefaultDeploymentScheduler(DeploymentScheduler):
def schedule(
@@ -953,3 +983,105 @@ def get_node_to_compact(
self, allow_new_compaction: bool
) -> Optional[Tuple[str, float]]:
return None

def schedule_gang_placement_groups(
self,
gang_requests: Dict[DeploymentID, GangPlacementGroupRequest],
) -> Dict[DeploymentID, GangPreparationResult]:
"""Pre-create placement groups for gang scheduling.

Creates gang placement groups before replicas are created, allowing
the scheduler to verify resource feasibility upfront.
"""
results = {}

for deployment_id, request in gang_requests.items():
result = self._prepare_gangs_for_deployment(deployment_id, request)
results[deployment_id] = result

return results

def _prepare_gangs_for_deployment(
self,
deployment_id: DeploymentID,
request: GangPlacementGroupRequest,
) -> GangPreparationResult:
"""Create gang placement groups for a single deployment.

Args:
deployment_id: The deployment to create gangs for.
request: Contains gang config and number of replicas to add.

Returns:
GangPreparationResult with success status and created PGs.
"""
gang_size = request.gang_size
num_gangs_needed = request.num_replicas_to_add // gang_size

gang_pgs = {}
created_pgs = [] # Track for cleanup on failure

for gang_index in range(num_gangs_needed):
# Build bundles - each bundle is for one replica in the gang
bundles = [
request.replica_resource_dict.copy()
for _ in range(gang_size)
]

pg_name = (
f"gang_{deployment_id.app_name}_{deployment_id.name}"
f"_{gang_index}_{uuid.uuid4().hex[:8]}"
)
strategy = request.gang_placement_strategy

try:
pg = self._create_placement_group_fn(
CreatePlacementGroupRequest(
bundles=bundles,
strategy=strategy,
target_node_id=None,
name=pg_name,
bundle_label_selector=None,
)
)
created_pgs.append(pg)

# Wait for placement group to be created with a timeout
# to check feasibility
GANG_PG_TIMEOUT_S = 30
if pg.wait(timeout_seconds=GANG_PG_TIMEOUT_S):
# PG is ready - store with bundle indices for replicas
gang_pgs[gang_index] = pg
else:
# PG creation timed out - infeasible. Query the PG state before
# cleanup; once removed, the table reports the state as REMOVED.
pg_table = ray.util.placement_group_table(pg)
state = pg_table.get("state", "UNKNOWN")
self._cleanup_gang_pgs(created_pgs)
return GangPreparationResult(
success=False,
error_message=(
f"Gang placement group '{pg_name}' is infeasible. "
f"State: {state}. Cluster may not have enough resources "
f"to schedule {gang_size} replicas together."
),
)

except Exception as e:
self._cleanup_gang_pgs(created_pgs)
logger.exception(
f"Failed to create gang placement group for {deployment_id}."
)
return GangPreparationResult(
success=False,
error_message=f"Failed to create gang placement group: {str(e)}",
)

return GangPreparationResult(success=True, gang_pgs=gang_pgs)
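
The create-then-wait feasibility check above follows a standard Ray pattern; a standalone sketch with made-up resource shapes:

import ray
from ray.util.placement_group import placement_group, remove_placement_group

pg = placement_group([{"CPU": 1}] * 4, strategy="STRICT_PACK", name="gang_demo")
if not pg.wait(timeout_seconds=30):
    # Read the state before removing the PG; afterwards it reports REMOVED.
    state = ray.util.placement_group_table(pg).get("state", "UNKNOWN")
    remove_placement_group(pg)
    raise RuntimeError(f"Gang of 4 is infeasible (state: {state}).")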

def _cleanup_gang_pgs(self, pgs: List[Any]) -> None:
"""Clean up placement groups on failure."""
for pg in pgs:
try:
ray.util.remove_placement_group(pg)
except Exception:
pass
Reviewer comment (Contributor, severity: medium):

Catching a broad Exception and passing silently can hide important issues during cleanup. While cleanup should be robust, it's better to at least log the exception to aid debugging; a permission issue or a problem with the GCS connection, for example, would otherwise go unnoticed.

Suggested change:

except Exception as e:
    logger.warning(f"Failed to remove placement group {pg.id}: {e}")
