Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions src/exo/master/placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from copy import deepcopy
from typing import Sequence

from loguru import logger

from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
Expand Down Expand Up @@ -127,26 +129,68 @@ def place_instance(
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")

if command.sharding == Sharding.Tensor:
# Asymmetric TP currently only supports Qwen3.5 at the worker level.
asymmetric_tp_families = {"qwen3_5"}

if (
command.sharding == Sharding.AsymmetricTensor
and command.model_card.family not in asymmetric_tp_families
):
raise ValueError(
f"Asymmetric tensor parallelism is not yet supported for "
f"family '{command.model_card.family}'. "
f"Supported: {asymmetric_tp_families}"
)

if command.sharding in (Sharding.Tensor, Sharding.AsymmetricTensor):
if not command.model_card.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
kv_heads = command.model_card.num_key_value_heads
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
and (kv_heads is None or kv_heads % len(cycle) == 0)
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with "
f"hidden_size={command.model_card.hidden_size}"
f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}"
f" across candidate cycles"
)
if command.sharding == Sharding.Tensor:
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
kv_heads = command.model_card.num_key_value_heads
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
and (kv_heads is None or kv_heads % len(cycle) == 0)
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with "
f"hidden_size={command.model_card.hidden_size}"
f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}"
f" across candidate cycles"
)

# Auto-upgrade to AsymmetricTensor when equal TP won't fit on
# the smallest node but asymmetric split would.
# Only for model families with tested asymmetric TP support.
if command.model_card.family in asymmetric_tp_families:
for cycle in cycles_with_sufficient_memory:
equal_share = command.model_card.storage_size.in_bytes / len(cycle)
min_node_mem = min(
node_memory[nid].ram_available.in_bytes for nid in cycle
)
if equal_share > min_node_mem * 0.9:
# Equal split too tight — try asymmetric
total_mem = sum(
node_memory[nid].ram_available.in_bytes
for nid in cycle
)
if (
command.model_card.storage_size.in_bytes
< total_mem * 0.85
):
logger.info(
"Equal tensor split won't fit on smallest node "
f"({min_node_mem / 1e9:.0f}GB available, "
f"needs {equal_share / 1e9:.0f}GB). "
"Auto-upgrading to AsymmetricTensor."
)
command.sharding = Sharding.AsymmetricTensor
break
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
Expand Down
86 changes: 86 additions & 0 deletions src/exo/master/placement_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
AsymmetricTensorShardMetadata,
CfgShardMetadata,
PipelineShardMetadata,
Sharding,
Expand Down Expand Up @@ -273,6 +274,85 @@ def get_shard_assignments_for_tensor_parallel(
return shard_assignments


def get_shard_assignments_for_asymmetric_tensor_parallel(
    model_card: ModelCard,
    cycle: Cycle,
    node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
    """Create shard assignments for asymmetric tensor parallelism.

    Each node gets a ratio of weights proportional to its available memory.
    All nodes compute every layer simultaneously.
    """
    # Local import to avoid pulling worker-side engine code into the
    # master module graph at import time.
    from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

    n_layers = model_card.n_layers
    world_size = len(cycle)

    # Rank 0 must be the biggest machine: the ratio solver returns
    # ratios[0] > 0.5, so order nodes by available memory, largest first.
    ranked = sorted(
        cycle,
        key=lambda node: node_memory[node].ram_available.in_bytes,
        reverse=True,
    )

    # Memory fraction of each node relative to the whole cycle
    # (same order as `ranked`, so largest fraction comes first).
    available = [node_memory[node].ram_available.in_bytes for node in ranked]
    cluster_total = sum(available)
    fractions = [mem / cluster_total for mem in available]

    ratios = find_valid_ratios(
        memory_fractions=fractions,
        hidden_size=model_card.hidden_size,
        # ModelCard only carries num_key_value_heads, not num_attention_heads.
        # hidden_size divisibility is a sufficient constraint for the ratio
        # solver; per-head validation happens at load time when the model
        # config is available.
        num_attention_heads=model_card.hidden_size // 128,  # conservative estimate from head_dim=128
        num_key_value_heads=model_card.num_key_value_heads or 2,
    )
    if ratios is None:
        raise ValueError(
            f"No valid asymmetric ratio found for hidden_size={model_card.hidden_size}"
        )

    # Every shard stores rank-0's ratio so all ranks agree on the same
    # global split point; each worker uses group.rank() to decide which
    # slice it keeps.
    split_point = ratios[0]

    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}
    for rank, node in enumerate(ranked):
        runner = RunnerId()
        node_to_runner[node] = runner
        runner_to_shard[runner] = AsymmetricTensorShardMetadata(
            model_card=model_card,
            device_rank=rank,
            world_size=world_size,
            start_layer=0,
            end_layer=n_layers,
            n_layers=n_layers,
            ratio=split_point,
        )

    logger.info(
        f"Asymmetric TP: ratios={[f'{r:.0%}' for r in ratios]} "
        f"across {world_size} nodes"
    )

    return ShardAssignments(
        model_id=model_card.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )


def get_shard_assignments(
model_card: ModelCard,
cycle: Cycle,
Expand All @@ -291,6 +371,12 @@ def get_shard_assignments(
model_card=model_card,
cycle=cycle,
)
case Sharding.AsymmetricTensor:
return get_shard_assignments_for_asymmetric_tensor_parallel(
model_card=model_card,
cycle=cycle,
node_memory=node_memory,
)


def get_mlx_jaccl_devices_matrix(
Expand Down
31 changes: 30 additions & 1 deletion src/exo/shared/types/worker/shards.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class Sharding(str, Enum):
    """Strategy for splitting a model across the nodes of a cycle."""

    # Equal split of each weight tensor across all participating nodes.
    Tensor = "Tensor"
    # Memory-proportional split of each weight tensor, for heterogeneous
    # clusters where an equal split would not fit on the smallest node.
    AsymmetricTensor = "AsymmetricTensor"
    # Pipeline parallelism (see PipelineShardMetadata).
    Pipeline = "Pipeline"


Expand Down Expand Up @@ -79,6 +80,34 @@ class TensorShardMetadata(BaseShardMetadata):
pass


@final
class AsymmetricTensorShardMetadata(BaseShardMetadata):
    """
    Asymmetric tensor parallelism shard metadata.

    Unlike standard tensor parallelism which splits weights 50/50 (or equally
    across N nodes), asymmetric TP splits weights proportionally to each node's
    available memory. This enables heterogeneous clusters (e.g. 128GB + 48GB)
    to run models using tensor parallelism where equal splits wouldn't fit.

    Each node holds a different fraction of each weight tensor, but ALL nodes
    compute every layer simultaneously. The all_sum reduction still works
    correctly because (x_a @ W_a^T) + (x_b @ W_b^T) = x @ W^T regardless
    of how W is partitioned.
    """

    # Identical on every rank by construction (placement copies ratios[0]
    # into each shard); workers derive their own slice from their rank.
    ratio: float = Field(
        ge=0.0,
        le=1.0,
        description="Split point for rank 0, shared across all ranks. "
        "e.g. 0.75 means rank 0 gets the first 75% and rank 1 gets the last 25%. "
        "Every rank stores the same value so all workers agree on the split.",
    )


# Union of every shard-metadata variant a runner may be assigned.
# NOTE: the pasted diff fused the old single-line alias with the new
# multi-line one; this is the post-change form, now including the
# AsymmetricTensor variant.
ShardMetadata: TypeAlias = (
    PipelineShardMetadata
    | CfgShardMetadata
    | TensorShardMetadata
    | AsymmetricTensorShardMetadata
)
Loading