diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index f9380693f9..f2411f2eaa 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -3,6 +3,8 @@ from copy import deepcopy from typing import Sequence +from loguru import logger + from exo.master.placement_utils import ( Cycle, filter_cycles_by_memory, @@ -127,26 +129,68 @@ def place_instance( if len(cycles_with_sufficient_memory) == 0: raise ValueError("No cycles found with sufficient memory") - if command.sharding == Sharding.Tensor: + # Asymmetric TP currently only supports Qwen3.5 at the worker level. + asymmetric_tp_families = {"qwen3_5"} + + if ( + command.sharding == Sharding.AsymmetricTensor + and command.model_card.family not in asymmetric_tp_families + ): + raise ValueError( + f"Asymmetric tensor parallelism is not yet supported for " + f"family '{command.model_card.family}'. " + f"Supported: {asymmetric_tp_families}" + ) + + if command.sharding in (Sharding.Tensor, Sharding.AsymmetricTensor): if not command.model_card.supports_tensor: raise ValueError( f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}" ) - # TODO: the condition here for tensor parallel is not correct, but it works good enough for now. - kv_heads = command.model_card.num_key_value_heads - cycles_with_sufficient_memory = [ - cycle - for cycle in cycles_with_sufficient_memory - if command.model_card.hidden_size % len(cycle) == 0 - and (kv_heads is None or kv_heads % len(cycle) == 0) - ] - if not cycles_with_sufficient_memory: - raise ValueError( - f"No tensor sharding found for model with " - f"hidden_size={command.model_card.hidden_size}" - f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}" - f" across candidate cycles" - ) + if command.sharding == Sharding.Tensor: + # TODO: the condition here for tensor parallel is not correct, but it works good enough for now. 
+ kv_heads = command.model_card.num_key_value_heads + cycles_with_sufficient_memory = [ + cycle + for cycle in cycles_with_sufficient_memory + if command.model_card.hidden_size % len(cycle) == 0 + and (kv_heads is None or kv_heads % len(cycle) == 0) + ] + if not cycles_with_sufficient_memory: + raise ValueError( + f"No tensor sharding found for model with " + f"hidden_size={command.model_card.hidden_size}" + f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}" + f" across candidate cycles" + ) + + # Auto-upgrade to AsymmetricTensor when equal TP won't fit on + # the smallest node but asymmetric split would. + # Only for model families with tested asymmetric TP support. + if command.model_card.family in asymmetric_tp_families: + for cycle in cycles_with_sufficient_memory: + equal_share = command.model_card.storage_size.in_bytes / len(cycle) + min_node_mem = min( + node_memory[nid].ram_available.in_bytes for nid in cycle + ) + if equal_share > min_node_mem * 0.9: + # Equal split too tight — try asymmetric + total_mem = sum( + node_memory[nid].ram_available.in_bytes + for nid in cycle + ) + if ( + command.model_card.storage_size.in_bytes + < total_mem * 0.85 + ): + logger.info( + "Equal tensor split won't fit on smallest node " + f"({min_node_mem / 1e9:.0f}GB available, " + f"needs {equal_share / 1e9:.0f}GB). " + "Auto-upgrading to AsymmetricTensor." 
def get_shard_assignments_for_asymmetric_tensor_parallel(
    model_card: ModelCard,
    cycle: Cycle,
    node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
    """Create shard assignments for asymmetric tensor parallelism.

    Every node in ``cycle`` receives a shard that spans all layers; the shards
    differ only in ``device_rank``.  Ranks are assigned by descending available
    RAM, so rank 0 is the node with the most memory.  All shards carry the same
    ``ratio`` (rank 0's share) so every worker derives the same split point.

    Args:
        model_card: Model description; ``n_layers``, ``hidden_size``,
            ``num_key_value_heads`` and ``model_id`` are read.
        cycle: Nodes taking part in the placement.
        node_memory: Memory usage per node; ``ram_available.in_bytes`` drives
            both the rank order and the target ratio.

    Returns:
        Assignments mapping each node to a freshly created runner and each
        runner to its ``AsymmetricTensorShardMetadata``.

    Raises:
        ValueError: If the cycle does not contain exactly two nodes, if the
            nodes report no available memory, or if the ratio solver finds no
            valid split.
    """
    total_layers = model_card.n_layers
    world_size = len(cycle)

    # The ratio solver only handles two nodes and returns None otherwise.
    # Fail here with an accurate message instead of the generic
    # "no valid ratio" error further down.
    if world_size != 2:
        raise ValueError(
            "Asymmetric tensor parallelism currently supports exactly 2 nodes, "
            f"got {world_size}"
        )

    # Sort nodes so the largest-memory node is rank 0. The ratio solver
    # returns ratios[0] > 0.5, so rank 0 must be the bigger machine.
    sorted_nodes = sorted(
        cycle, key=lambda nid: node_memory[nid].ram_available.in_bytes, reverse=True
    )

    # Memory fractions, largest first.
    available = [
        node_memory[node_id].ram_available.in_bytes for node_id in sorted_nodes
    ]
    total_available = sum(available)
    if total_available <= 0:
        # Avoids a ZeroDivisionError when no node reports free memory.
        raise ValueError(
            "Cannot compute asymmetric split: nodes report no available memory"
        )
    memory_fractions = [amount / total_available for amount in available]

    # Imported lazily, as in the original code, so this module does not pull
    # in the worker engine at import time.
    from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

    ratios = find_valid_ratios(
        memory_fractions=memory_fractions,
        hidden_size=model_card.hidden_size,
        # ModelCard only carries num_key_value_heads, not num_attention_heads.
        # hidden_size divisibility is a sufficient constraint for the ratio
        # solver; per-head validation happens at load time when the model
        # config is available.
        num_attention_heads=model_card.hidden_size // 128,  # conservative estimate from head_dim=128
        num_key_value_heads=model_card.num_key_value_heads or 2,
    )
    if ratios is None:
        raise ValueError(
            f"No valid asymmetric ratio found for hidden_size={model_card.hidden_size}"
        )

    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    # Store rank-0's ratio on every shard so all ranks agree on the
    # same global split point. The worker uses group.rank() to decide
    # which slice it keeps.
    rank0_ratio = ratios[0]

    for i, node_id in enumerate(sorted_nodes):
        shard = AsymmetricTensorShardMetadata(
            model_card=model_card,
            device_rank=i,
            world_size=world_size,
            start_layer=0,
            end_layer=total_layers,
            n_layers=total_layers,
            ratio=rank0_ratio,
        )
        runner_id = RunnerId()
        runner_to_shard[runner_id] = shard
        node_to_runner[node_id] = runner_id

    logger.info(
        f"Asymmetric TP: ratios={[f'{r:.0%}' for r in ratios]} "
        f"across {world_size} nodes"
    )

    return ShardAssignments(
        model_id=model_card.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )
@final
class AsymmetricTensorShardMetadata(BaseShardMetadata):
    """
    Asymmetric tensor parallelism shard metadata.

    Unlike standard tensor parallelism which splits weights 50/50 (or equally
    across N nodes), asymmetric TP splits weights proportionally to each node's
    available memory. This enables heterogeneous clusters (e.g. 128GB + 48GB)
    to run models using tensor parallelism where equal splits wouldn't fit.

    Each node holds a different fraction of each weight tensor, but ALL nodes
    compute every layer simultaneously. The all_sum reduction still works
    correctly because (x_a @ W_a^T) + (x_b @ W_b^T) = x @ W^T regardless
    of how W is partitioned.
    """

    # The bounds are exclusive: a ratio of exactly 0.0 or 1.0 would leave one
    # rank with an empty slice of every weight tensor.
    ratio: float = Field(
        gt=0.0,
        lt=1.0,
        description="Split point for rank 0, shared across all ranks. "
        "e.g. 0.75 means rank 0 gets the first 75% and rank 1 gets the last 25%. "
        "Every rank stores the same value so all workers agree on the split.",
    )
def find_valid_ratios(
    memory_fractions: list[float],
    hidden_size: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    num_experts: int = 0,
    moe_intermediate_size: int = 0,
    linear_num_value_heads: int = 0,
    linear_num_key_heads: int = 0,
    quantization_group_size: int = 64,
) -> list[float] | None:
    """
    Find valid split ratios for asymmetric TP given model dimensions and memory fractions.

    Candidate ratios are ``n/d`` with ``d`` in (2, 4, 8, 16, 32) and
    ``0.5 < n/d <= 0.95``.  A candidate is valid when every checked dimension
    splits into two positive integer parts and, for dimensions larger than
    ``quantization_group_size``, both parts are multiples of 8.  Among the
    valid candidates the one closest to ``memory_fractions[0]`` is chosen;
    ties keep the candidate found first (smallest denominator).

    Args:
        memory_fractions: Per-node share of total memory; index 0 is the
            target for rank 0.  Must have exactly two entries.
        hidden_size: Model hidden size (always checked).
        num_attention_heads: Attention head count (always checked).
        num_key_value_heads: Accepted for interface symmetry; not used by the
            solver.
        num_experts: When positive (together with ``moe_intermediate_size``),
            the MoE intermediate size is checked as well.
        moe_intermediate_size: Per-expert intermediate size.
        linear_num_value_heads: When positive, this and
            ``linear_num_key_heads`` are checked as well.
        linear_num_key_heads: Linear-attention key head count.
        quantization_group_size: Dimensions strictly larger than this value
            must split into multiples of 8.

    Returns:
        ``[ratio, 1.0 - ratio]`` for the best candidate, or ``None`` when the
        node count is not 2 or no candidate is valid.
    """
    if len(memory_fractions) != 2:
        logger.warning("Asymmetric TP currently only supports 2 nodes")
        return None

    # Key dimensions that must split cleanly
    key_dims = [num_attention_heads, hidden_size]
    if linear_num_value_heads > 0:
        key_dims.extend([linear_num_value_heads, linear_num_key_heads])
    if num_experts > 0 and moe_intermediate_size > 0:
        key_dims.append(moe_intermediate_size)

    pack_alignment = 8

    def splits_cleanly(dim: int, numer: int, denom: int) -> bool:
        """Return True when ``dim`` splits at ``numer/denom`` into valid parts."""
        scaled = dim * numer
        # Exact integer test: no reliance on float equality.
        if scaled % denom != 0:
            return False
        first = scaled // denom
        second = dim - first
        if first <= 0 or second <= 0:
            return False
        # For quantized weights, split dims must be divisible by 8
        if dim > quantization_group_size and (
            first % pack_alignment != 0 or second % pack_alignment != 0
        ):
            return False
        return True

    target_ratio = memory_fractions[0]
    best_ratio: float | None = None
    best_distance = float("inf")

    for denom in (2, 4, 8, 16, 32):
        # Start above denom/2 so only ratios strictly greater than 0.5 are tried.
        for numer in range(denom // 2 + 1, denom):
            ratio = numer / denom
            if ratio > 0.95:
                break  # ratios only grow with numer

            if not all(splits_cleanly(dim, numer, denom) for dim in key_dims):
                continue

            distance = abs(ratio - target_ratio)
            if distance < best_distance:
                best_distance = distance
                best_ratio = ratio

    if best_ratio is None:
        return None

    return [best_ratio, 1.0 - best_ratio]
def _shard_quantized_sta(layer: Any, rank: int, ratio: float) -> None:
    """Shard a quantized linear layer sharded-to-all (input dim, axis -1).

    Replaces ``weight``, ``scales`` and ``biases`` on ``layer`` in place with
    this rank's slice along the last axis.  Attributes that are missing or
    ``None`` are left untouched.

    Args:
        layer: Layer object whose tensors are sliced.
        rank: This process's rank; rank 0 keeps the leading part of the
            split, any other rank keeps the trailing part.
        ratio: Rank 0's share of the split axis.
    """
    input_axis = -1
    for name in ("weight", "scales", "biases"):
        tensor: mx.array | None = getattr(layer, name, None)
        if tensor is not None:
            setattr(layer, name, _my_shard(tensor, input_axis, rank, ratio))
# Patching must happen at the class level: Python looks up ``__call__`` on the
# type, so an instance-level override would be ignored.
_attention_class_patched: set[type] = set()


def _patch_attention_class(attn_cls: type) -> None:
    """Patch an attention class to add all_sum when _asymmetric_tp_group is set.

    The class's ``__call__`` is replaced by a wrapper that runs the original
    call and, only for instances carrying a non-``None``
    ``_asymmetric_tp_group`` attribute, reduces the result with
    ``mx.distributed.all_sum`` over that group.  Instances without the
    attribute behave exactly as before.

    The patch is applied at most once per class.  A subclass of an
    already-patched class inherits the wrapper, so it is only recorded as
    patched; wrapping it again would reduce the output twice.

    Args:
        attn_cls: The attention class to patch.
    """
    if attn_cls in _attention_class_patched:
        return

    original_call = attn_cls.__call__

    # Attribute lookup follows the MRO, so this may already be the wrapper
    # installed on a base class.
    if getattr(original_call, "_asymmetric_tp_patched", False):
        _attention_class_patched.add(attn_cls)
        return

    def patched_call(
        self: nn.Module,
        x: mx.array,
        mask: mx.array | None = None,
        cache: object | None = None,
    ) -> mx.array:
        result = original_call(self, x, mask=mask, cache=cache)
        grp = getattr(self, "_asymmetric_tp_group", None)
        if grp is not None:
            result = mx.distributed.all_sum(result, group=grp)
        return result

    # Marker used above to recognise an inherited wrapper.
    setattr(patched_call, "_asymmetric_tp_patched", True)

    attn_cls.__call__ = patched_call
    _attention_class_patched.add(attn_cls)
def asymmetric_tensor_auto_parallel(
    model: nn.Module,
    group: mx.distributed.Group,
    ratios: list[float],
) -> nn.Module:
    """
    Apply asymmetric tensor parallelism to a model.

    The layer list is looked up on ``model.language_model`` / ``.model`` (when
    those exist and expose ``layers``) and otherwise on ``model`` itself.  Each
    layer is sharded in place: its attention part (linear or self-attention)
    and its MLP part (sparse MoE or dense).

    Args:
        model: The model to parallelize (must have .layers property)
        group: MLX distributed group; must contain exactly two ranks.
        ratios: Per-rank weight fractions, e.g. [0.75, 0.25] for 2 nodes.
            ratios[group.rank()] is this node's fraction.  Only ``ratios[0]``
            is used as the split point; rank 1 keeps the remainder.

    Returns:
        The model with asymmetric sharding applied.

    Raises:
        ValueError: If the group or ``ratios`` is not 2-way, if ``ratios[0]``
            is not strictly between 0 and 1, or if a layer is of an
            unsupported type.
    """
    rank = group.rank()
    world_size = group.size()

    # The sharding helpers give rank 0 the leading slice and every other rank
    # the trailing slice, which is only correct for two ranks.
    if world_size != 2 or len(ratios) != 2:
        raise ValueError(
            "Asymmetric TP currently supports exactly 2 ranks "
            f"(group size {world_size}, {len(ratios)} ratios given)"
        )

    ratio = ratios[0]  # ratio for rank 0; rank 1 gets 1-ratio
    if not 0.0 < ratio < 1.0:
        raise ValueError(
            f"Asymmetric TP ratio must be strictly between 0 and 1, got {ratio}"
        )

    # Get the inner model's layers
    inner = model
    for attr in ("language_model", "model"):
        candidate = getattr(inner, attr, None)
        if candidate is not None and hasattr(candidate, "layers"):
            inner = candidate

    layers: list[Any] = inner.layers if hasattr(inner, "layers") else model.layers

    for layer in layers:
        if not isinstance(layer, Qwen3_5DecoderLayer):
            raise ValueError(
                f"Asymmetric TP does not yet support layer type {type(layer).__name__}. "
                f"Currently supported: Qwen3.5 (GatedDeltaNet + Attention + MoE). "
                f"Contributions for other architectures welcome."
            )

        # Qwen3.5 hybrid: linear_attn or self_attn per layer
        if layer.is_linear:
            _shard_gated_delta_net(layer.linear_attn, rank, ratio, group)
        else:
            _shard_attention(layer.self_attn, rank, ratio, group)

        mlp = layer.mlp
        if isinstance(mlp, (SparseMoeBlock, Qwen3_5SparseMoeBlock)):
            _shard_sparse_moe(mlp, rank, ratio, group)
        else:
            _shard_dense_mlp(mlp, rank, ratio, group)

    logger.info(
        f"Asymmetric TP applied: rank {rank} gets "
        f"{ratios[rank] * 100:.0f}% of each weight tensor"
    )
    return model
class TestFindValidRatios:
    """Test the ratio solver that finds valid asymmetric split points."""

    def test_qwen3_5_122b_dimensions(self) -> None:
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        ratios = find_valid_ratios(
            memory_fractions=[0.73, 0.27],
            hidden_size=3072,
            num_attention_heads=32,
            num_key_value_heads=2,
            linear_num_value_heads=64,
            linear_num_key_heads=16,
            moe_intermediate_size=1024,
            num_experts=256,
        )
        assert ratios is not None
        assert len(ratios) == 2
        assert abs(ratios[0] + ratios[1] - 1.0) < 1e-10
        # All head counts must be exact integers after split
        assert 32 * ratios[0] == int(32 * ratios[0])  # attention heads
        assert 64 * ratios[0] == int(64 * ratios[0])  # value heads
        assert 16 * ratios[0] == int(16 * ratios[0])  # key heads

    def test_llama_70b_dimensions(self) -> None:
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        ratios = find_valid_ratios(
            memory_fractions=[0.73, 0.27],
            hidden_size=8192,
            num_attention_heads=64,
            num_key_value_heads=8,
        )
        assert ratios is not None
        assert 64 * ratios[0] == int(64 * ratios[0])

    def test_nemotron_120b_dimensions(self) -> None:
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        ratios = find_valid_ratios(
            memory_fractions=[0.73, 0.27],
            hidden_size=4096,
            num_attention_heads=32,
            num_key_value_heads=2,
        )
        assert ratios is not None
        assert 32 * ratios[0] == int(32 * ratios[0])

    def test_rejects_impossible_dimensions(self) -> None:
        """Prime-number head count with no valid fractional split."""
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        ratios = find_valid_ratios(
            memory_fractions=[0.73, 0.27],
            hidden_size=3072,
            # 7 shares no factor with the power-of-two denominators the solver
            # tries, so no candidate ratio yields an integer head count.
            num_attention_heads=7,
            num_key_value_heads=2,
        )
        assert ratios is None

    def test_only_two_nodes_supported(self) -> None:
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        ratios = find_valid_ratios(
            memory_fractions=[0.5, 0.25, 0.25],
            hidden_size=4096,
            num_attention_heads=32,
            num_key_value_heads=8,
        )
        assert ratios is None

    def test_ratio_closer_to_target(self) -> None:
        """Ratio should be the closest valid one to the memory fraction."""
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        # With 80% target, 0.8125 (13/16) is closer than 0.75 (12/16)
        ratios = find_valid_ratios(
            memory_fractions=[0.80, 0.20],
            hidden_size=3072,
            num_attention_heads=32,
            num_key_value_heads=2,
            linear_num_value_heads=64,
            linear_num_key_heads=16,
        )
        assert ratios is not None
        assert abs(ratios[0] - 0.80) < abs(0.75 - 0.80)
        # Pin the value the comment above promises (exact: dyadic fraction).
        assert ratios[0] == 13 / 16

    def test_equal_memory_returns_near_symmetric(self) -> None:
        """When memory is roughly equal, ratio should be close to 0.5."""
        from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios

        ratios = find_valid_ratios(
            memory_fractions=[0.50, 0.50],
            hidden_size=4096,
            num_attention_heads=32,
            num_key_value_heads=8,
        )
        # The solver only searches ratios > 0.5, so the result is the
        # smallest valid ratio above 0.5. Require a result so the check
        # cannot pass vacuously.
        assert ratios is not None
        assert 0.5 < ratios[0] < 0.6