forked from vllm-project/tpu-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathray_distributed_executor.py
More file actions
426 lines (361 loc) · 18.4 KB
/
ray_distributed_executor.py
File metadata and controls
426 lines (361 loc) · 18.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from typing import Dict, List, Optional
import ray
import vllm.envs as envs
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
get_open_port)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.ray_distributed_executor import \
RayDistributedExecutor as RayDistributedExecutorV1
from vllm.v1.executor.ray_executor import RayWorkerMetaData
from vllm.v1.executor.ray_utils import RayWorkerWrapper as RayWorkerWrapperV1
from vllm.v1.executor.ray_utils import _wait_until_pg_ready
from tpu_inference.distributed.jax_parallel_state import get_pp_group
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
try:
from ray._private.state import available_resources_per_node
except ImportError:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from ray._private.state import state as _state
available_resources_per_node = _state._available_resources_per_node
import asyncio
from collections import defaultdict
from tpu_inference.distributed.utils import set_node_kv_ip_port
logger = init_logger(__name__)
class RayDistributedExecutor(RayDistributedExecutorV1):
"""Ray-based distributed executor for TPU.
The implementation is similar to vllm/executor/ray_distributed_executor.py
with these major differences:
1. self._init_executor():
VLLM_USE_RAY_SPMD_WORKER=1, in which the driver worker is the same as other workers.
2. self._initialize_ray_cluster():
This sets placement_group_specs for TPU.
In vLLM one GPU maps to one placement group.
While here one TPU node with all chips maps to one placement group.
3. self._init_workers_ray():
This set TPU resources when create each worker.
And we omit the driver worker related logic.
"""
def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = True
# If it is true, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = True
assert self.uses_ray
self._initialize_ray_cluster()
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.scheduler_output: SchedulerOutput | None = None
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
if self.has_connector:
ip_port = self.collective_rpc("get_node_kv_ip_port")
for item in ip_port:
set_node_kv_ip_port(item)
self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
self.vllm_config.ec_transfer_config is None
or not self.vllm_config.ec_transfer_config.is_ec_producer)
def _initialize_ray_cluster(self) -> None:
"""Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
"""
from vllm.platforms import current_platform
if ray.is_initialized():
logger.info(
"Ray is already initialized. Skipping Ray initialization.")
else:
logger.warning("Ray is not initialized, this is mainly for test.")
ray.init()
device_str = current_platform.ray_device_key
if not device_str:
raise ValueError(
f"current platform {current_platform.device_name} does not "
"support ray.")
pp_size = self.parallel_config.pipeline_parallel_size
placement_group_specs: List[Dict[str, float]] = []
ray_nodes = ray.nodes()
logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
if pp_size == 1:
placement_group_specs = [{
device_str: node['Resources'][device_str]
} for node in ray_nodes]
else:
assert pp_size == len(
ray_nodes
), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
placement_group_specs = [{
device_str: num_devices_per_pp_rank
} for _ in range(pp_size)]
# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
# the current node has at least one device.
current_ip = get_ip()
current_node_id = ray.get_runtime_context().get_node_id()
current_node_resource = available_resources_per_node()[current_node_id]
if current_node_resource.get(device_str, 0) < 1:
raise ValueError(
f"Current node has no {device_str} available. "
f"{current_node_resource=}. vLLM engine cannot start without "
f"{device_str}. Make sure you have at least 1 {device_str} "
f"available in a node {current_node_id=} {current_ip=}.")
# This way, at least bundle is required to be created in a current
# node.
placement_group_specs[0][f"node:{current_ip}"] = 0.001
logger.info(
f"RayDistributedExecutor | placement_group_specs={placement_group_specs}"
)
# Use STRICT_SPREAD for PP to ensure each host participates in JAX initialization.
strategy = "STRICT_SPREAD" if pp_size > 1 else "PACK"
current_placement_group = ray.util.placement_group(
placement_group_specs, strategy=strategy)
_wait_until_pg_ready(current_placement_group)
assert current_placement_group is not None
# Set the placement group in the parallel config
self.parallel_config.placement_group = current_placement_group
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# The workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
# Create the workers.
bundle_indices: List[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
assert len(bundle_indices) == self.parallel_config.world_size, \
("VLLM_RAY_BUNDLE_INDICES must have the same size"
f" as the world size, but got {bundle_indices=} "
f"and {self.parallel_config.world_size=}")
assert len(set(bundle_indices)) == len(bundle_indices), \
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
f" but got {bundle_indices=}")
else:
bundle_indices = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if bundle.get(current_platform.ray_device_key, 0):
bundle_indices.append(bundle_id)
worker_metadata: List[RayWorkerMetaData] = []
driver_ip = get_ip()
num_tpu_per_worker = placement_group.bundle_specs[0].get(
current_platform.ray_device_key, 0)
for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={
current_platform.ray_device_key: num_tpu_per_worker
},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(rpc_rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
])
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
logger.debug(f"Initialized worker_metadata: {worker_metadata}")
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
sorted_worker_metadata = sorted(worker_metadata,
key=sort_by_driver_then_worker_ip)
start_rank = 0
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
logger.info(
f"Initialized sorted worker_metadata: {sorted_worker_metadata}")
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self.collective_rpc("adjust_rank", args=(rerank_mapping, ))
# Get the set of TPU IDs used on each node.
worker_node_and_tpu_ids = []
for worker in self.workers:
worker_node_and_tpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
node_workers = defaultdict(list) # node id -> list of worker ranks
node_tpus = defaultdict(list) # node id -> list of tpu ids
for i, (node_id, tpu_ids) in enumerate(worker_node_and_tpu_ids):
node_workers[node_id].append(i)
# `tpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
tpu_ids = [int(x) for x in tpu_ids]
node_tpus[node_id].extend(tpu_ids)
for node_id, tpu_ids in node_tpus.items():
node_tpus[node_id] = sorted(tpu_ids)
logger.info(
f"RayDistributedExecutor | node_workers={node_workers} | node_tpus={node_tpus}"
)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
logger.warning(
f"Got {n_nodes} nodes but with {n_ips} IP addresses. "
"This is not a typical production setup whose "
"number of nodes and IPs is euqal. This setup may "
"lead to unexpected behaviors.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [{
current_platform.device_control_env_var:
",".join(map(str, node_tpus[node_id])),
} for (node_id, _) in worker_node_and_tpu_ids]
# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars),
destination="workers")
# Copy existing env vars to each worker's args, but don't overwrite
# and don't copy global topology bounds from the driver.
for i, args in enumerate(all_args_to_update_environment_variables):
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
logger.debug(f"Worker {i} environment variables: {args}")
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self.collective_rpc("update_environment_variables",
args=(self._get_env_vars_to_be_updated(), ))
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Get the Driver's Node ID to identify local vs remote workers
driver_node_id = ray.get_runtime_context().get_node_id()
# Initialize the actual workers inside worker wrapper.
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
local_rank = node_workers[node_id].index(rank)
ip = sorted_worker_metadata[rank].ip
prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
worker_vllm_config = self.vllm_config
# When using object storage (e.g., RunAI), the Leader updates `model` to its local
# cache path (e.g., /root/.cache/...) during ModelConfig initialization
# (maybe_pull_model_tokenizer_for_runai), while `model_weights` preserves the original URI after the model is pulled.
# (Standard HF downloads do not overwrite `model`, allowing workers to pull normally).
# Since workers on remote nodes cannot access the Leader's filesystem, we create a
# worker-specific config copy and restore the original GCS URI from `model_weights`.
# This allows each worker to independently invoke `maybe_pull_model_tokenizer_for_runai`
# and stream the model from GCS.
if node_id != driver_node_id and getattr(
self.vllm_config, "model_config", None) and getattr(
self.vllm_config.model_config, "model_weights", None):
worker_vllm_config = copy.deepcopy(self.vllm_config)
worker_vllm_config.model_config.model = worker_vllm_config.model_config.model_weights
# Unset model_weights so maybe_pull_model_tokenizer_for_runai will pull the model.
worker_vllm_config.model_config.model_weights = None
kwargs = dict(
vllm_config=worker_vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
ip=ip,
prev_worker_ip=prev_ip,
)
# NOTE(Chenyaaang): Adjust worker's rank to 0 if PP=1.
# Otherwise if we have 4 ray nodes each with 1 chip and use TP=4,
# We'll have 4 workers with rank 0,1,2,3 respectively. This
# contradicts with get_pp_group().
if self.parallel_config.pipeline_parallel_size == 1:
kwargs["rank"] = 0
all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs, ))
self.collective_rpc("init_device")
if self.parallel_config.pipeline_parallel_size > 1:
self.collective_rpc("initialize_pp_transfer_connect")
self.collective_rpc("load_model")
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
num_tp_workers = int(
len(self.workers) //
self.parallel_config.pipeline_parallel_size)
for tp_rank in range(num_tp_workers):
# PP=2, TP=4, num_tpu_per_worker=2
# pp_tp_workers = [[0, 1], [2, 3]]
rank = (pp_rank * num_tp_workers) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# Ray executor do not need handshake metadata
# as we pass the kv_parameters through proxy server
def get_kv_connector_handshake_metadata(self) -> None:
pass
class RayWorkerWrapper(RayWorkerWrapperV1):
"""
Ray worker wrapper for TPU.
The implementation is similar to vllm/v1/executor/ray_utils.py
_is_intermediate_tensors: check whether the output is JaxIntermediateTensors.
_is_last_rank: check whether this Ray worker is the last PP stage.
"""
def _is_intermediate_tensors(self, output) -> bool:
return isinstance(output, JaxIntermediateTensors)
def _is_last_rank(self) -> bool:
return get_pp_group().is_last_rank