-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Expand file tree
/
Copy pathplacement.py
More file actions
377 lines (334 loc) · 13.4 KB
/
placement.py
File metadata and controls
377 lines (334 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
import random
from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CancelDownload,
CreateInstance,
DeleteInstance,
DownloadCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
    """Return a random port from the ephemeral range, never 52415.

    Draws from 49153-65535, then shifts any draw <= 52415 down by one so
    that 52415 itself is never returned (presumably reserved by exo's own
    API server — TODO confirm). Side effect: 49152 becomes a possible
    result and the distribution over 49152-52414 / 52416-65535 stays
    uniform.
    """
    candidate = random.randint(49153, 65535)
    if candidate <= 52415:
        return candidate - 1
    return candidate
def add_instance_to_placements(
    command: CreateInstance,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
) -> Mapping[InstanceId, Instance]:
    """Return a copy of *current_instances* with the command's instance added.

    The input mapping is not mutated; the command's instance overwrites any
    existing entry under the same instance id.
    """
    # TODO: validate against topology
    updated = dict(current_instances)
    updated[command.instance.instance_id] = command.instance
    return updated
def _get_node_download_fraction(
    node_id: NodeId,
    model_id: ModelId,
    download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> float:
    """Return the download fraction (0.0–1.0) for a model on a given node.

    Scans the node's progress entries for the first one matching *model_id*
    and maps its state to a fraction: completed -> 1.0, failed -> 0.0,
    ongoing/pending -> downloaded/total (0.0 when total is zero). Nodes with
    no matching entry report 0.0.
    """
    for entry in download_status.get(node_id, []):
        if entry.shard_metadata.model_card.model_id != model_id:
            continue
        if isinstance(entry, DownloadCompleted):
            return 1.0
        if isinstance(entry, DownloadOngoing):
            total_bytes = entry.download_progress.total.in_bytes
            if total_bytes > 0:
                return entry.download_progress.downloaded.in_bytes / total_bytes
            return 0.0
        if isinstance(entry, DownloadPending):
            # NOTE(review): pending entries expose downloaded/total directly,
            # unlike ongoing ones which nest them under download_progress.
            total_bytes = entry.total.in_bytes
            if total_bytes > 0:
                return entry.downloaded.in_bytes / total_bytes
            return 0.0
        if isinstance(entry, DownloadFailed):
            return 0.0
    return 0.0
def _cycle_download_score(
    cycle: Cycle,
    model_id: ModelId,
    download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> float:
    """Sum of download fractions across all nodes in a cycle.

    Higher scores mean more of the model is already present on the cycle's
    nodes, so less data needs downloading before the instance can run.
    """
    fractions = [
        _get_node_download_fraction(member, model_id, download_status)
        for member in cycle
    ]
    return sum(fractions)
def place_instance(
    command: PlaceInstance,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
    node_memory: Mapping[NodeId, MemoryUsage],
    node_network: Mapping[NodeId, NodeNetworkInfo],
    required_nodes: set[NodeId] | None = None,
    download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None,
) -> dict[InstanceId, Instance]:
    """Choose a cycle of nodes for a new instance and return the updated placements.

    Filters the topology's cycles by size, required-node membership, memory,
    and sharding constraints; then prefers the smallest viable cycles, cycles
    containing leaf nodes, and finally the cycle with the most model data
    already downloaded (RAM as tiebreaker). Builds an MlxJaccl or MlxRing
    instance for the selected cycle and returns a deep copy of
    *current_instances* with the new instance added.

    NOTE(review): mutates *command* in place — ``command.sharding`` may be
    auto-upgraded to AsymmetricTensor, and both ``command.instance_meta`` and
    ``command.sharding`` are forced to MlxRing/Pipeline for single-node cycles.

    Raises:
        ValueError: when no cycle satisfies the memory, sharding, or RDMA
            constraints, or the model/sharding combination is unsupported.
    """
    cycles = topology.get_cycles()
    # Cycles must have at least the requested minimum number of nodes.
    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
    # Filter to cycles containing all required nodes (subset matching)
    if required_nodes:
        candidate_cycles = [
            cycle
            for cycle in candidate_cycles
            if required_nodes.issubset(cycle.node_ids)
        ]
    cycles_with_sufficient_memory = filter_cycles_by_memory(
        candidate_cycles, node_memory, command.model_card.storage_size
    )
    if len(cycles_with_sufficient_memory) == 0:
        raise ValueError("No cycles found with sufficient memory")
    # Asymmetric TP currently only supports Qwen3.5 at the worker level.
    asymmetric_tp_families = {"qwen3_5"}
    if (
        command.sharding == Sharding.AsymmetricTensor
        and command.model_card.family not in asymmetric_tp_families
    ):
        raise ValueError(
            f"Asymmetric tensor parallelism is not yet supported for "
            f"family '{command.model_card.family}'. "
            f"Supported: {asymmetric_tp_families}"
        )
    if command.sharding in (Sharding.Tensor, Sharding.AsymmetricTensor):
        if not command.model_card.supports_tensor:
            raise ValueError(
                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
            )
    if command.sharding == Sharding.Tensor:
        # TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
        # Equal TP requires the hidden size (and KV heads, when known) to
        # divide evenly across the cycle's nodes.
        kv_heads = command.model_card.num_key_value_heads
        cycles_with_sufficient_memory = [
            cycle
            for cycle in cycles_with_sufficient_memory
            if command.model_card.hidden_size % len(cycle) == 0
            and (kv_heads is None or kv_heads % len(cycle) == 0)
        ]
        if not cycles_with_sufficient_memory:
            raise ValueError(
                f"No tensor sharding found for model with "
                f"hidden_size={command.model_card.hidden_size}"
                f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}"
                f" across candidate cycles"
            )
        # Auto-upgrade to AsymmetricTensor when equal TP won't fit on
        # the smallest node but asymmetric split would.
        # Only for model families with tested asymmetric TP support.
        if command.model_card.family in asymmetric_tp_families:
            for cycle in cycles_with_sufficient_memory:
                equal_share = command.model_card.storage_size.in_bytes / len(cycle)
                min_node_mem = min(
                    node_memory[nid].ram_available.in_bytes for nid in cycle
                )
                # 0.9 / 0.85 are safety margins — presumably tuned empirically.
                if equal_share > min_node_mem * 0.9:
                    # Equal split too tight — try asymmetric
                    total_mem = sum(
                        node_memory[nid].ram_available.in_bytes
                        for nid in cycle
                    )
                    if (
                        command.model_card.storage_size.in_bytes
                        < total_mem * 0.85
                    ):
                        logger.info(
                            "Equal tensor split won't fit on smallest node "
                            f"({min_node_mem / 1e9:.0f}GB available, "
                            f"needs {equal_share / 1e9:.0f}GB). "
                            "Auto-upgrading to AsymmetricTensor."
                        )
                        command.sharding = Sharding.AsymmetricTensor
                        break
    if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
        "mlx-community/DeepSeek-V3.1-8bit"
    ):
        raise ValueError(
            "Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
        )
    # Prefer the fewest nodes; RDMA-only when the command demands MlxJaccl.
    smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
    smallest_rdma_cycles = [
        cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
    ]
    if command.instance_meta == InstanceMeta.MlxJaccl:
        if not smallest_rdma_cycles:
            raise ValueError(
                "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
            )
        smallest_cycles = smallest_rdma_cycles
    # Prefer cycles that include at least one leaf node of the topology.
    cycles_with_leaf_nodes: list[Cycle] = [
        cycle
        for cycle in smallest_cycles
        if any(topology.node_is_leaf(node_id) for node_id in cycle)
    ]
    resolved_download_status = download_status or {}
    candidate_cycles = (
        cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles
    )
    # Pick the cycle with the most model data already downloaded,
    # breaking ties by total available RAM.
    selected_cycle = max(
        candidate_cycles,
        key=lambda cycle: (
            _cycle_download_score(
                cycle, command.model_card.model_id, resolved_download_status
            ),
            sum(
                (node_memory[node_id].ram_available for node_id in cycle),
                start=Memory(),
            ),
        ),
    )
    # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node)
    if len(selected_cycle) == 1:
        command.instance_meta = InstanceMeta.MlxRing
        command.sharding = Sharding.Pipeline
    shard_assignments = get_shard_assignments(
        command.model_card, selected_cycle, command.sharding, node_memory
    )
    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
    instance_id = InstanceId()
    target_instances = dict(deepcopy(current_instances))
    match command.instance_meta:
        case InstanceMeta.MlxJaccl:
            # TODO(evan): shard assignments should contain information about ranks, this is ugly
            def get_device_rank(node_id: NodeId) -> int:
                # Rank assigned to this node's runner in the shard layout.
                runner_id = shard_assignments.node_to_runner[node_id]
                shard_metadata = shard_assignments.runner_to_shard.get(runner_id)
                assert shard_metadata is not None
                return shard_metadata.device_rank
            # Exactly one node must hold rank 0; it becomes the coordinator.
            zero_node_ids = [
                node_id
                for node_id in selected_cycle.node_ids
                if get_device_rank(node_id) == 0
            ]
            assert len(zero_node_ids) == 1
            coordinator_node_id = zero_node_ids[0]
            mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
                [node_id for node_id in selected_cycle],
                cycle_digraph,
            )
            mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
                coordinator=coordinator_node_id,
                coordinator_port=random_ephemeral_port(),
                cycle_digraph=cycle_digraph,
                node_network=node_network,
            )
            target_instances[instance_id] = MlxJacclInstance(
                instance_id=instance_id,
                shard_assignments=shard_assignments,
                jaccl_devices=mlx_jaccl_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
            )
        case InstanceMeta.MlxRing:
            # All ring hosts share a single ephemeral port.
            ephemeral_port = random_ephemeral_port()
            hosts_by_node = get_mlx_ring_hosts_by_node(
                selected_cycle=selected_cycle,
                cycle_digraph=cycle_digraph,
                ephemeral_port=ephemeral_port,
                node_network=node_network,
            )
            target_instances[instance_id] = MlxRingInstance(
                instance_id=instance_id,
                shard_assignments=shard_assignments,
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
            )
    return target_instances
def delete_instance(
    command: DeleteInstance,
    current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
    """Return a deep copy of *current_instances* without the given instance.

    Raises:
        ValueError: if the instance id is not present.
    """
    updated = dict(deepcopy(current_instances))
    if command.instance_id not in updated:
        raise ValueError(f"Instance {command.instance_id} not found")
    del updated[command.instance_id]
    return updated
def get_transition_events(
    current_instances: Mapping[InstanceId, Instance],
    target_instances: Mapping[InstanceId, Instance],
    tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
    """Compute the events that transition current instances to the target set.

    Emits an InstanceCreated for every instance only in the target set; for
    every instance only in the current set, first cancels its pending/running
    tasks, then emits an InstanceDeleted.
    """
    transition_events: list[Event] = []
    # Newly placed instances -> creation events.
    for new_id, new_instance in target_instances.items():
        if new_id in current_instances:
            continue
        transition_events.append(InstanceCreated(instance=new_instance))
    # Removed instances -> cancel live tasks, then delete.
    for old_id in current_instances:
        if old_id in target_instances:
            continue
        for task in tasks.values():
            if task.instance_id != old_id:
                continue
            if task.task_status in [TaskStatus.Pending, TaskStatus.Running]:
                transition_events.append(
                    TaskStatusUpdated(
                        task_status=TaskStatus.Cancelled,
                        task_id=task.task_id,
                    )
                )
        transition_events.append(
            InstanceDeleted(
                instance_id=old_id,
            )
        )
    return transition_events
def cancel_unnecessary_downloads(
    instances: Mapping[InstanceId, Instance],
    download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> Sequence[DownloadCommand]:
    """Cancel in-flight downloads for models no placed instance still needs."""
    # (node, model) pairs with a download currently in flight.
    in_flight = [
        (node_id, progress.shard_metadata.model_card.model_id)
        for node_id, progresses in download_status.items()
        for progress in progresses
        if isinstance(progress, DownloadOngoing)
    ]
    # (node, model) pairs some placed instance still requires.
    needed = {
        (
            node_id,
            instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
        )
        for instance in instances.values()
        for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
    }
    return [
        CancelDownload(target_node_id=node_id, model_id=model_id)
        for node_id, model_id in in_flight
        if (node_id, model_id) not in needed
    ]