From 05b66b282b847247e0c5017398ba4551bbc33735 Mon Sep 17 00:00:00 2001 From: Youngkyu Choi Date: Tue, 24 Mar 2026 18:56:32 +0900 Subject: [PATCH 1/3] feat(vllm_rbln): enable P/D disaggregation with NIXL host KV transfer wire vLLM KV transfer to a RBLN-specific NIXL connector and host-side buffers so prefill/decode can run on separate engines with H2H transfer. KV connector / registration - add RblnNixlConnector (scheduler/worker) extending upstream NixlConnector: - register connector name "RblnNixlConnector" in kv_connector factory. Platform - expose NIXL hints: get_nixl_supported_devices (rbln -> cpu) and get_nixl_memory_type ("DRAM"). Scheduler (rbln_scheduler.py) - handle kv_consumer request to be scheduled with other requests in decode stage Model runner (rbln_model_runner.py) - override maybe_get_kv_connector_output(..., wait_for_save) using last prefill chunk. - replace generic copy_kv_blocks with rbln_copy_kv_blocks using runtime _update_kv_cache / _fetch_kv_cache - bind_kv_cache_name + per-layer names for mark_static_address when compiling. Attention backend (flash_attention.py) - Report backend name as FLASH_ATTN for upstream compatibility. Examples - add experimental examples/experimental/pd_disaggregation/toy_proxy_server.py (FastAPI proxy routing chat completions to prefill vs decode HTTP backends). --- .../pd_disaggregation/toy_proxy_server.py | 297 ++++++++++++++++++ vllm_rbln/__init__.py | 1 + .../kv_transfer/kv_connector/factory.py | 21 ++ .../kv_connector/v1/rbln_nixl_connector.py | 215 +++++++++++++ vllm_rbln/platform.py | 10 + .../v1/attention/backends/flash_attention.py | 20 +- vllm_rbln/v1/core/rbln_scheduler.py | 24 +- vllm_rbln/v1/worker/rbln_model_runner.py | 120 +++++-- vllm_rbln/v1/worker/rbln_worker.py | 26 +- vllm_rbln/v1/worker/utils.py | 53 ++++ 10 files changed, 735 insertions(+), 52 deletions(-) create mode 100644 examples/experimental/pd_disaggregation/toy_proxy_server.py create mode 100644 vllm_rbln/distributed/kv_transfer/kv_connector/factory.py create mode 100644 vllm_rbln/distributed/kv_transfer/kv_connector/v1/rbln_nixl_connector.py diff --git a/examples/experimental/pd_disaggregation/toy_proxy_server.py b/examples/experimental/pd_disaggregation/toy_proxy_server.py new file mode 100644 index 000000000..b8539a22b --- /dev/null +++ b/examples/experimental/pd_disaggregation/toy_proxy_server.py @@ -0,0 +1,297 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import itertools +import logging +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
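+
+    On startup, build one httpx.AsyncClient pool per prefill and per decode
+    backend; on shutdown, close every pooled client.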
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f"http://{host}:{port}/v1" + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=prefiller_base_url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + "host": host, + "port": port, + "id": i, + } + ) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f"http://{host}:{port}/v1" + app.state.decode_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=decoder_base_url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + "host": host, + "port": port, + "id": i, + } + ) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) + + print( + f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients." + ) + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info["client"].aclose() + + for client_info in app.state.decode_clients: + await client_info["client"].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI + parser.add_argument("--host", type=str, default="127.0.0.1") + + # For prefiller instances + parser.add_argument( + "--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"], + ) + parser.add_argument( + "--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100] + ) + + # For decoder instances + parser.add_argument( + "--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"] + ) + parser.add_argument( + "--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200] + ) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports" + ) + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError("Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. 
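+    Each service type keeps its own itertools.cycle iterator, so prefill
+    and decode backends are rotated independently of each other.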
+ + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == "prefill": + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == "decode": + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service( + client_info: dict, endpoint: str, req_data: dict, request_id: str +): + """ + Send a request to a service using a client from the pool. + """ + req_data = req_data.copy() + req_data["kv_transfer_params"] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None, + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + response = await client_info["client"].post( + endpoint, json=req_data, headers=headers + ) + response.raise_for_status() + + # read/consume the response body to release the connection + # otherwise, it would http.ReadError + await response.aread() + + return response + + +async def stream_service_response( + client_info: dict, endpoint: str, req_data: dict, request_id: str +): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + async with client_info["client"].stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +async def _handle_completions(api: str, request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, "prefill") + + # Send request to prefill service + response = await send_request_to_service( + prefill_client_info, api, req_data, request_id + ) + + # Extract the needed fields + response_json = response.json() + await response.aclose() # CRITICAL: Release connection back to pool + kv_transfer_params = response_json.get("kv_transfer_params", {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, "decode") + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response( + decode_client_info, api, req_data, request_id=request_id + ): + yield chunk + + return StreamingResponse(generate_stream(), media_type="application/json") + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print(f"Error occurred in disagg prefill proxy server - {api} endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle_completions("/completions", request) + + +@app.post("/v1/chat/completions") +async def 
handle_chat_completions(request: Request): + return await _handle_completions("/chat/completions", request) + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients), + } + + +if __name__ == "__main__": + global global_args + global_args = parse_args() + + import uvicorn + + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/vllm_rbln/__init__.py b/vllm_rbln/__init__.py index fb82d6390..7b6523a36 100644 --- a/vllm_rbln/__init__.py +++ b/vllm_rbln/__init__.py @@ -45,6 +45,7 @@ def register_model(): def register_ops(): if envs.VLLM_RBLN_USE_VLLM_MODEL: import vllm_rbln.attention.layer # noqa + import vllm_rbln.distributed.kv_transfer.kv_connector.factory # noqa import vllm_rbln.forward_context # noqa import vllm_rbln.lora.layer # noqa import vllm_rbln.model_executor.layers.fused_moe.layer # noqa diff --git a/vllm_rbln/distributed/kv_transfer/kv_connector/factory.py b/vllm_rbln/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 000000000..bdbc8c8a7 --- /dev/null +++ b/vllm_rbln/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,21 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory + +KVConnectorFactory.register_connector( + "RblnNixlConnector", + "vllm_rbln.distributed.kv_transfer.kv_connector.v1.rbln_nixl_connector", + "RblnNixlConnector", +) diff --git a/vllm_rbln/distributed/kv_transfer/kv_connector/v1/rbln_nixl_connector.py b/vllm_rbln/distributed/kv_transfer/kv_connector/v1/rbln_nixl_connector.py new file mode 100644 index 000000000..af9dcd2bb --- /dev/null +++ b/vllm_rbln/distributed/kv_transfer/kv_connector/v1/rbln_nixl_connector.py @@ -0,0 +1,215 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
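+
+# RBLN-specific NIXL connector. The scheduler half builds per-step transfer
+# metadata and decides when a finished prefill's blocks may be freed; the
+# worker half stages KV blocks in host DRAM buffers (rebel.kv_cache
+# aligned_tensor) so NIXL can move them host-to-host between the prefill
+# and decode engines.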
+ +import time +from typing import TYPE_CHECKING, Any, Optional + +import torch +from rebel.kv_cache import aligned_tensor +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + CopyBlocksOp, + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + EngineId, + NixlConnector, + NixlConnectorMetadata, + NixlConnectorScheduler, + NixlConnectorWorker, +) +from vllm.v1.core.sched.output import SchedulerOutput + +from vllm_rbln.logger import init_logger + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class RblnNixlConnector(NixlConnector): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + KVConnectorBase_V1.__init__(self, vllm_config, role, kv_cache_config) + + assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: RblnNixlConnectorScheduler | None = ( + RblnNixlConnectorScheduler(vllm_config, self.engine_id) + ) + self.connector_worker: RblnNixlConnectorWorker | None = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = RblnNixlConnectorWorker(vllm_config, self.engine_id) + + +class RblnNixlConnectorScheduler(NixlConnectorScheduler): + def __init__(self, vllm_config: VllmConfig, engine_id: str): + super().__init__(vllm_config, engine_id) + + self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req_to_recv( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + for req_id, (req, block_ids) in self._reqs_need_save.items(): + assert req.kv_transfer_params is not None + meta.add_new_req_to_save( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + meta.reqs_to_send = self._reqs_need_send # type: ignore[var-annotated, has-type] + meta.reqs_in_batch = self._reqs_in_batch # type: ignore[var-annotated, has-type] + meta.reqs_not_processed = self._reqs_not_processed # type: ignore[var-annotated, has-type] + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + self._reqs_in_batch = set() # type: ignore[var-annotated] + self._reqs_not_processed = set() # type: ignore[var-annotated] + self._reqs_need_send = {} # type: ignore[var-annotated] + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. 
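+
+        Returns (delay_free_blocks, kv_transfer_params). Blocks are kept
+        alive only while the remote decode instance still needs to read
+        them, bounded by VLLM_NIXL_ABORT_REQUEST_TIMEOUT.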
+ """ + from vllm.v1.request import RequestStatus + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector request_finished(%s), request_status=%s, " + "kv_transfer_params=%s", + request.request_id, + request.status, + params, + ) + if not params: + return False, None + + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if not params.get("do_remote_decode"): + return False, None + + if request.request_id in self._reqs_need_save: + del self._reqs_need_save[request.request_id] + + if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + # Also include the case of a P/D Prefill request with immediate + # block free (eg abort). Stop tracking this request. + self._reqs_not_processed.add(request.request_id) + return False, None + + # TODO: check whether block_ids actually ever be 0. If not we could + # remove the conditional below + delay_free_blocks = len(block_ids) > 0 + + if delay_free_blocks: + # Prefill request on remote. It will be read from D upon completion + logger.debug( + "NIXLConnector request_finished(%s) waiting for %d seconds " + "for remote decode to fetch blocks", + request.request_id, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + ) + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=block_ids, + remote_engine_id=self.engine_id, + remote_request_id=request.request_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + ) + + +class RblnNixlConnectorWorker(NixlConnectorWorker): + def __init__(self, vllm_config: VllmConfig, engine_id: str): + super().__init__(vllm_config, engine_id) + + self.use_host_buffer = self.kv_buffer_device == "cpu" + + def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: + """ + Initialize transfer buffer in CPU mem for accelerators + NOT directly supported by NIXL (e.g., tpu) + """ + assert self.kv_cache_layout == "HND", ( + "RBLN NIXL Connector only supports HND layout" + ) + xfer_buffers: dict[str, torch.Tensor] = {} + try: + for layer_name, kv_cache in kv_caches.items(): + xfer_buffers[layer_name] = aligned_tensor(kv_cache.numel()).reshape( + kv_cache.shape + ) + except MemoryError as e: + logger.error("RBLNNIXLConnectorWorker gets %s.", e) + raise + + self.host_xfer_buffers = xfer_buffers + + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): + """Assign copy (d2h, h2d) operations when host buffer is used.""" + # Set a no-op if the host buffer is not cpu. 
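+        # `copy_operation` must match CopyBlocksOp: roughly
+        # (src_kv_caches, dst_kv_caches, src_block_ids, dst_block_ids,
+        #  direction: "h2d" | "d2h") -> None. The model runner passes
+        # rbln_copy_kv_blocks here (see initialize_kv_cache in
+        # rbln_model_runner.py).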
+ if self.kv_buffer_device != "cpu": + return + assert self.use_host_buffer + self.copy_blocks = copy_operation diff --git a/vllm_rbln/platform.py b/vllm_rbln/platform.py index cdbd58549..34478768a 100644 --- a/vllm_rbln/platform.py +++ b/vllm_rbln/platform.py @@ -354,3 +354,13 @@ def get_punica_wrapper(cls) -> str: @classmethod def can_update_inplace(cls) -> bool: return False + + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + return { + "rbln": ("cpu",), + } + + @classmethod + def get_nixl_memory_type(cls) -> str | None: + return "DRAM" diff --git a/vllm_rbln/v1/attention/backends/flash_attention.py b/vllm_rbln/v1/attention/backends/flash_attention.py index 3a3d503b3..ba6b71a9d 100644 --- a/vllm_rbln/v1/attention/backends/flash_attention.py +++ b/vllm_rbln/v1/attention/backends/flash_attention.py @@ -942,7 +942,7 @@ def get_supported_head_sizes(cls) -> list[int]: @staticmethod def get_name() -> str: - return "RBLN_ATTN" + return "FLASH_ATTN" @staticmethod def get_impl_cls() -> type["RBLNFlashAttentionImpl"]: @@ -1075,9 +1075,9 @@ def build( common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, - num_tokens=None, positions=None, batch_pad=None, + is_prefill=False, ) -> RBLNFlashAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -1114,18 +1114,10 @@ def build( ).to(torch.int16) seq_lens_tensor = dyn_size_for_partitions - assert num_tokens is not None, ( - "num_tokens is required for RBLN Attention Backend" - ) assert batch_pad is not None, "batch_pad is required for RBLN Attention Backend" - is_prefills = num_computed_tokens[:num_reqs].numpy() < num_tokens[:num_reqs] - 1 - # The prefill and decode cannot be mixed. - assert len(is_prefills) > 0 and all( - is_prefill == is_prefills[0] for is_prefill in is_prefills[:num_reqs] - ) attn_masks = None - if is_prefills[0]: + if is_prefill: # NOTE(jiwoo.park) prefill's block_tables must be a 1D tensor. 
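             # (prefill runs one request at a time, so only row 0 is valid)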
block_tables_tensor = block_tables_tensor[0] if not self.is_causal: @@ -1181,7 +1173,7 @@ def build( query_lens = seq_lens - num_computed_tokens cache_seq_lens = torch.clamp(num_computed_tokens, max=sliding_window) cache_offsets = cache_seq_lens + query_lens - if not is_prefills[0]: + if not is_prefill: cache_seq_lens = rbln_utils.pad(cache_seq_lens, 0, batch_pad) cache_offsets = rbln_utils.pad(cache_offsets, 0, batch_pad) # Generate sliding window attention mask for decode @@ -1203,7 +1195,7 @@ def build( query_start_loc=query_start_loc, max_seq_len=query_max_seq_len, seq_lens=seq_lens_tensor.to(self.device) - if not self.is_batch_attention_opt or is_prefills[0] or batch_pad <= 1 + if not self.is_batch_attention_opt or is_prefill or batch_pad <= 1 else seq_idx.to(self.device), block_tables=block_tables_tensor.to(self.device), slot_mapping=slot_mapping, @@ -1214,7 +1206,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, - is_prefill=bool(is_prefills[0]), + is_prefill=is_prefill, attn_masks=attn_masks, cache_seq_lens=cache_seq_lens.to(self.device) if cache_seq_lens is not None diff --git a/vllm_rbln/v1/core/rbln_scheduler.py b/vllm_rbln/v1/core/rbln_scheduler.py index 997f7d292..b571171fe 100644 --- a/vllm_rbln/v1/core/rbln_scheduler.py +++ b/vllm_rbln/v1/core/rbln_scheduler.py @@ -364,6 +364,7 @@ def schedule(self) -> SchedulerOutput: request = self.waiting.peek_request() + is_ready = False # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: is_ready = self._update_waiting_for_remote_kv(request) @@ -520,9 +521,14 @@ def schedule(self) -> SchedulerOutput: # Therefore, we allocate based on # request.num_tokens - num_computed_tokens, # not num_new_tokens + num_external_computed_tokens. + num_tokens_to_allocate = ( + num_new_tokens + num_external_computed_tokens + if load_kv_async + else request.num_tokens - num_computed_tokens + ) new_blocks = self.kv_cache_manager.allocate_slots( request, - request.num_tokens - num_computed_tokens, + num_tokens_to_allocate, num_new_local_computed_tokens, new_computed_blocks, num_lookahead_tokens=effective_lookahead_tokens, @@ -543,9 +549,12 @@ def schedule(self) -> SchedulerOutput: # tokens, the block may not be fully computed. # Therefore, if the block is not finalized in this iteration, # we must clear the block hash and undo block caching. - undo_uncomputed_block_caching( - request, self.kv_cache_manager, num_computed_tokens + num_new_tokens - ) + if not load_kv_async: + undo_uncomputed_block_caching( + request, + self.kv_cache_manager, + num_computed_tokens + num_new_tokens, + ) # KVTransfer: the connector uses this info to determine # if a load is needed. Note that @@ -610,6 +619,13 @@ def schedule(self) -> SchedulerOutput: if self.ec_connector is not None: self.ec_connector.update_state_after_alloc(request, i) + # If the request' previous state is WAITING_FOR_REMOTE_KVS, + # we can continue the scheduling process. + if is_ready: + # token_budget is only used for assertion checks. + token_budget -= num_new_tokens + continue + # NOTE(RBLN): Reaching this point means that this request # can now be added to the running batch. 
# However, since we do not support mixed batching for now, diff --git a/vllm_rbln/v1/worker/rbln_model_runner.py b/vllm_rbln/v1/worker/rbln_model_runner.py index e36d73b41..5f480a0d7 100644 --- a/vllm_rbln/v1/worker/rbln_model_runner.py +++ b/vllm_rbln/v1/worker/rbln_model_runner.py @@ -18,9 +18,9 @@ import time from collections import defaultdict from collections.abc import Iterator, Sequence -from contextlib import contextmanager +from contextlib import AbstractContextManager, contextmanager, nullcontext from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Union, cast import numpy as np import rebel @@ -39,7 +39,6 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group -from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import get_dp_group, get_pp_group, get_tp_group from vllm.forward_context import set_forward_context from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -124,6 +123,7 @@ from vllm_rbln.v1.spec_decoding.medusa import RBLNMedusaProposer from vllm_rbln.v1.worker.bucketing import get_bucketing_manager from vllm_rbln.v1.worker.metrics import PerformanceTracker +from vllm_rbln.v1.worker.utils import bind_kv_cache_name if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -283,6 +283,7 @@ def __init__( if "use_global_ctx" in inspect.signature(CompileContext).parameters: compile_ctx_args["use_global_ctx"] = True self.compile_context = CompileContext(**compile_ctx_args) + self.runtime_holder = [] # type: ignore[var-annotated] # Sampler self.use_rbln_sampler = envs.VLLM_RBLN_SAMPLER @@ -312,6 +313,9 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + # Initialize in initialize_kv_cache_tensors + self.cross_layers_kv_cache: torch.Tensor | None = None + self.cross_layers_attn_backend: type[AttentionBackend] | None = None # indexes: [kv_cache_group_id][attn_group] self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig @@ -1219,12 +1223,13 @@ def _prepare_inputs( num_reqs ) + is_first_request_prefill = self.is_first_request_prefill() (batch_bucket_size, num_padded_tokens, num_tokens_across_dp) = ( self.get_dp_padding( total_num_scheduled_tokens, initial_batch_bucket_size, num_padded_tokens, - bool(self.is_prefills()[0]), + is_first_request_prefill, ) ) assert batch_bucket_size is not None @@ -1302,11 +1307,9 @@ def _prepare_inputs( ) if isinstance(builder, RBLNFlashAttentionMetadataBuilder): - extra_attn_metadata_args["num_tokens"] = ( - self.input_batch.num_tokens_no_spec - ) extra_attn_metadata_args["positions"] = self.positions.cpu extra_attn_metadata_args["batch_pad"] = batch_bucket_size + extra_attn_metadata_args["is_prefill"] = is_first_request_prefill attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, @@ -1357,6 +1360,7 @@ def _compile_model(self, model): "process_group_dict": process_group_dict, "guard_filter_fn": torch.compiler.keep_tensor_guards_unsafe, "mode": "strict", + "_runtime_holder": self.runtime_holder, } if not envs.VLLM_DISABLE_COMPILE_CACHE: logger.info( @@ -1756,7 +1760,7 @@ def _sample( if 
envs.VLLM_RBLN_METRICS and self.sampler_performance_tracker is not None: self.collect_metrics( self.sampler_performance_tracker, - self.is_prefills()[0], + self.is_first_request_prefill(), start_time=sampler_start_time, end_time=time.perf_counter(), reports=sampler_reports, @@ -2167,7 +2171,6 @@ def _prepare_dummy_inputs( extra_attn_metadata_args = {} if isinstance(builder, RBLNFlashAttentionMetadataBuilder): - extra_attn_metadata_args["num_tokens"] = input_batch.num_tokens extra_attn_metadata_args["positions"] = positions extra_attn_metadata_args["batch_pad"] = batch_bucket_size attn_metadata_i = builder.build( @@ -2494,6 +2497,20 @@ def _bookkeeping_sync( invalid_req_indices, ) + @staticmethod + def maybe_get_kv_connector_output( + scheduler_output: "SchedulerOutput", + wait_for_save: bool, + ) -> AbstractContextManager[KVConnectorOutput | None]: + warm_up_phase = scheduler_output.kv_connector_metadata is None + return ( + KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output, wait_for_save + ) + if has_kv_transfer_group() and not warm_up_phase + else nullcontext() + ) + @torch.inference_mode() def execute_model( self, @@ -2511,7 +2528,8 @@ def execute_model( with record_function_or_nullcontext("Preprocess"): self._update_states(scheduler_output) if not num_scheduled_tokens: - if not has_kv_transfer_group(): + warm_up_phase = scheduler_output.kv_connector_metadata is None + if not has_kv_transfer_group() or warm_up_phase: # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output, self.vllm_config) @@ -2553,7 +2571,8 @@ def execute_model( ) = self._preprocess(scheduler_output) assert input_ids is not None - is_prefills = self.is_prefills() + is_first_request_prefill = self.is_first_request_prefill() + is_last_prefill_chunk = self.is_last_prefill_chunk() # Padding length for speculative decoding by num_speculative_tokens if scheduler_output.scheduled_spec_decode_tokens: @@ -2610,7 +2629,9 @@ def execute_model( num_padded_tokens=num_padded_tokens, ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, is_last_prefill_chunk + ) as kv_connector_output, ): if attn_metadata is not None: for attn_metadatum in attn_metadata.values(): @@ -2622,15 +2643,9 @@ def execute_model( positions = positions.view(num_reqs, -1) token_indices = None - if is_prefills[0]: + if is_first_request_prefill: # DO NOT include compute logits if lora_config is enabled token_indices = logits_indices - - # The prefill and decode cannot be mixed. - assert len(is_prefills) > 0 and all( - is_prefill == is_prefills[0] for is_prefill in is_prefills[:num_reqs] - ) - if is_prefills[0]: # prefill chunk padding prefill_size = self.scheduler_config.max_num_batched_tokens input_ids = rbln_utils.pad(input_ids, -1, prefill_size) @@ -2667,7 +2682,7 @@ def execute_model( lora_ids, self.lora_manager._adapter_manager.lora_index_to_id, batch_bucket_size, - is_prefills[0], + is_first_request_prefill, self.lora_config.max_loras, self.device, ) @@ -2706,7 +2721,7 @@ def execute_model( device_time = reports[0].get("total_device", None) ccl_time = reports[0].get("total_ccl", None) - if is_prefills[0]: + if is_first_request_prefill: self.performance_tracker.record_prefill( model_execution_time, num_scheduled_tokens, @@ -2772,7 +2787,7 @@ def execute_model( # SHOULD resolve the batch dimension. 
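         # flatten(0, -2) collapses [batch, seq, hidden] into [tokens, hidden].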
hidden_states = hidden_states.flatten(0, -2) - if is_prefills[0]: # prefill + if is_first_request_prefill: # prefill sample_hidden_states = hidden_states[logits_indices] logits = self.compute_logits(sample_hidden_states) else: # decode @@ -2784,14 +2799,9 @@ def execute_model( else: selected_token_indices = logits_indices assert selected_token_indices.dim() == 1 - if is_prefills[0]: # prefill + if is_first_request_prefill: # prefill assert selected_token_indices.size(0) == 1 - num_computed = self.input_batch.num_computed_tokens_cpu - num_prompted = self.input_batch.num_prompt_tokens - is_last_prefill = ( - num_computed + self.max_num_tokens - ) >= num_prompted - if not is_last_prefill[0]: # noqa: SIM108 + if not is_last_prefill_chunk: # noqa: SIM108 # chunked prefill(#0~#N-1, intermediate) # token_indices = torch.tensor([max_num_seqs-1]) # selected = torch.tensor([]) @@ -3946,8 +3956,14 @@ def initialize_kv_cache_tensors( num_attn_module, ) if not self.model_config.enforce_eager and envs.VLLM_RBLN_COMPILE_MODEL: - for kv_cache in self.kv_caches: - self.compile_context.mark_static_address(kv_cache) + kv_cache_names: list[str] = [] + bind_kv_cache_name( + kv_caches, + kv_cache_names, + num_attn_module, + ) + for kv_cache, kv_cache_name in zip(self.kv_caches, kv_cache_names): + self.compile_context.mark_static_address(kv_cache, f"{kv_cache_name}") return kv_caches @@ -4023,7 +4039,38 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: ) else: kv_transfer_group.register_kv_caches(kv_caches) - kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + + def rbln_copy_kv_blocks( + src_kv_caches: dict[str, torch.Tensor], + dst_kv_caches: dict[str, torch.Tensor], + src_block_ids: list[int], + dst_block_ids: list[int], + direction: Literal["h2d", "d2h"], + ) -> None: + """Copy kv blocks between different buffers.""" + if ( + not src_kv_caches + or not dst_kv_caches + or not src_block_ids + or not dst_block_ids + or len(src_block_ids) != len(dst_block_ids) + ): + return + assert len(self.runtime_holder) > 0 + runtime = self.runtime_holder[0] + if direction == "h2d": + kv_caches = src_kv_caches + copy_fn = runtime._update_kv_cache + else: + kv_caches = dst_kv_caches + copy_fn = runtime._fetch_kv_cache + + for idx in src_block_ids: + for kv_name, kv_cache in kv_caches.items(): + block_size = kv_cache.shape[-2] + copy_fn(kv_cache.data_ptr(), idx, 0, block_size, kv_name) + + kv_transfer_group.set_host_xfer_buffer_ops(rbln_copy_kv_blocks) if self.dcp_world_size > 1: layer_type = cast(type[Any], AttentionLayerBase) @@ -4119,6 +4166,15 @@ def is_prefills(self) -> np.ndarray: < self.input_batch.num_tokens_no_spec - 1 ) + def is_first_request_prefill(self) -> bool: + return bool(self.is_prefills()[0]) + + def is_last_prefill_chunk(self) -> bool: + num_computed = self.input_batch.num_computed_tokens_cpu + num_prompted = self.input_batch.num_prompt_tokens + is_last_prefill = (num_computed + self.max_num_tokens) >= num_prompted + return bool(is_last_prefill[0]) + def use_wrapped_compute_logits(self) -> bool: return not ( self.lora_config is not None diff --git a/vllm_rbln/v1/worker/rbln_worker.py b/vllm_rbln/v1/worker/rbln_worker.py index bc229b986..4bd5e3917 100644 --- a/vllm_rbln/v1/worker/rbln_worker.py +++ b/vllm_rbln/v1/worker/rbln_worker.py @@ -36,8 +36,15 @@ init_distributed_environment, set_custom_all_reduce, ) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.distributed.parallel_state import get_pp_group +from vllm.distributed.kv_transfer 
import ( + ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group, +) +from vllm.distributed.parallel_state import ( + get_pp_group, + get_tp_group, +) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform @@ -324,6 +331,21 @@ def determine_available_memory(self) -> int: return available_memory_estimate + def get_kv_connector_handshake_metadata(self) -> dict | None: + """Get KV connector metadata from this worker if available.""" + + if not has_kv_transfer_group(): + return None + + connector = get_kv_transfer_group() + # Return None for connectors that don't need to exchange handshake + # metadata across workers. + if (metadata := connector.get_handshake_metadata()) is None: + return None + + tp_rank = get_tp_group().rank_in_group + return {tp_rank: metadata} + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() diff --git a/vllm_rbln/v1/worker/utils.py b/vllm_rbln/v1/worker/utils.py index 5ccdfeb3b..6eb318890 100644 --- a/vllm_rbln/v1/worker/utils.py +++ b/vllm_rbln/v1/worker/utils.py @@ -16,9 +16,12 @@ import math import os import platform +from collections import defaultdict from collections.abc import Callable +import torch from vllm.config import ModelConfig, ParallelConfig +from vllm.model_executor.models.utils import extract_layer_index from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo @@ -396,3 +399,53 @@ def set_omp_num_threads( rank, local_rank, ) + + +def bind_kv_cache_name( + kv_caches: dict[str, torch.Tensor], + runner_kv_cache_names: list[str], + num_attn_module: int = 1, +) -> None: + """ + Bind the allocated KV cache name to ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache name list (`runner_kv_cache_names`) with + kv_caches. + 2) Copied and Modified from vllm.v1.worker.utils.bind_kv_cache + Args: + kv_caches: The allocated kv_caches with layer names as keys. + runner_kv_cache_names: The kv_cache name list declared by ModelRunner. + """ + # Bind kv_cache names to ModelRunner + assert len(runner_kv_cache_names) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. + + # TODO - analyze where runner_kv_caches is used and the right + # way to ensure it properly reflects multiple attention layers + # in the same decoder block. + if ( + current_platform.is_cuda_alike() + or current_platform.is_xpu() + or current_platform.is_cpu() + ): + # We know that the GPU / CPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. 
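+                # (only layer_names[0] is bound below, so the extra names
+                # sharing this layer_index are intentionally dropped)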
+ pass + else: + raise NotImplementedError + layer_name = layer_names[0] + runner_kv_cache_names.append(layer_name) From baa04f86281a4ec5c4cc3c744020d219879abb10 Mon Sep 17 00:00:00 2001 From: Youngkyu Choi Date: Tue, 24 Mar 2026 21:15:08 +0900 Subject: [PATCH 2/3] add tools/install_nixl_from_source_ubuntu.py --- tools/install_nixl_from_source_ubuntu.py | 307 +++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 tools/install_nixl_from_source_ubuntu.py diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py new file mode 100644 index 000000000..c9eb307d0 --- /dev/null +++ b/tools/install_nixl_from_source_ubuntu.py @@ -0,0 +1,307 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# install_prerequisites.py +import argparse +import glob +import json +import os +import subprocess +import sys +import urllib.request + +# --- Configuration --- +WHEELS_CACHE_HOME = os.environ.get("WHEELS_CACHE_HOME", "/tmp/wheels_cache") +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +UCX_DIR = os.path.join("/tmp", "ucx_source") +NIXL_DIR = os.path.join("/tmp", "nixl_source") +UCX_INSTALL_DIR = os.path.join("/tmp", "ucx_install") +UCX_REPO_URL = "https://github.com/openucx/ucx.git" +NIXL_REPO_URL = "https://github.com/ai-dynamo/nixl.git" + + +# --- Helper Functions --- +def get_latest_nixl_version(): + """Helper function to get latest release version of NIXL""" + try: + nixl_release_url = "https://api.github.com/repos/ai-dynamo/nixl/releases/latest" + with urllib.request.urlopen(nixl_release_url) as response: + data = json.load(response) + return data.get("tag_name", "0.7.0") + except Exception: + return "0.7.0" + + +NIXL_VERSION = os.environ.get("NIXL_VERSION", get_latest_nixl_version()) + + +def run_command(command, cwd=".", env=None): + """Helper function to run a shell command and check for errors.""" + print(f"--> Running command: {' '.join(command)} in '{cwd}'", flush=True) + subprocess.check_call(command, cwd=cwd, env=env) + + +def is_pip_package_installed(package_name): + """Checks if a package is installed via pip without raising an exception.""" + result = subprocess.run( + [sys.executable, "-m", "pip", "show", package_name], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return result.returncode == 0 + + +def find_nixl_wheel_in_cache(cache_dir): + """Finds a nixl wheel file in the specified cache directory.""" + # The repaired wheel will have a 'manylinux' tag, but this glob still works. 
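+    # e.g. with NIXL_VERSION=0.7.0 this matches both the raw
+    # nixl_cu12-0.7.0-...-linux_x86_64.whl and the repaired
+    # ...-manylinux_*_x86_64.whl (file names illustrative).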
+ search_pattern = os.path.join(cache_dir, f"nixl*{NIXL_VERSION}*.whl") + wheels = glob.glob(search_pattern) + if wheels: + # Sort to get the most recent/highest version if multiple exist + wheels.sort() + return wheels[-1] + return None + + +def get_site_packages_dir(): + """Returns the site-packages directory for the current Python.""" + result = subprocess.run( + [ + sys.executable, + "-c", + "import site; print(site.getsitepackages()[0])", + ], + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() + + +def install_nixl_shim_for_meta_compat(site_packages_dir): + """ + Installs a minimal 'nixl' shim so that `import nixl` and + `from nixl._api import nixl_agent` work like with the PyPI meta package. + The source-built wheel installs as 'nixl-cu12' (import name: nixl_cu12); + this shim re-exposes it under the 'nixl' namespace. + """ + nixl_dir = os.path.join(site_packages_dir, "nixl") + os.makedirs(nixl_dir, exist_ok=True) + init_py = os.path.join(nixl_dir, "__init__.py") + shim_content = '''# Shim so that source-installed nixl-cu12 is usable as "nixl" (meta-package style). +# See install_nixl_from_source_ubuntu.py +import sys +import nixl_cu12 +sys.modules["nixl._api"] = nixl_cu12._api +sys.modules["nixl._bindings"] = nixl_cu12._bindings +''' + with open(init_py, "w") as f: + f.write(shim_content) + print("--> Installed 'nixl' shim for meta-package-style import (import nixl, from nixl._api import ...)", flush=True) + + +def install_system_dependencies(): + """Installs required system packages using apt-get if run as root.""" + if os.geteuid() != 0: + print("\n---", flush=True) + print( + "WARNING: Not running as root. \ + Skipping system dependency installation.", + flush=True, + ) + print( + "Please ensure the listed packages are installed on your system:", + flush=True, + ) + print( + " patchelf build-essential git cmake ninja-build \ + autotools-dev automake meson libtool libtool-bin", + flush=True, + ) + print("---\n", flush=True) + return + + print("--- Running as root. Installing system dependencies... ---", flush=True) + apt_packages = [ + "patchelf", # <-- Add patchelf here + "build-essential", + "git", + "cmake", + "ninja-build", + "autotools-dev", + "automake", + "meson", + "libtool", + "libtool-bin", + "pkg-config", + ] + run_command(["apt-get", "update"]) + run_command(["apt-get", "install", "-y"] + apt_packages) + print("--- System dependencies installed successfully. ---\n", flush=True) + + +def build_and_install_prerequisites(args): + """Builds UCX and NIXL from source, creating a self-contained wheel.""" + + if not args.force_reinstall and is_pip_package_installed("nixl"): + print("--> NIXL is already installed. Nothing to do.", flush=True) + return + + cached_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME) + if not args.force_reinstall and cached_wheel: + print( + f"\n--> Found self-contained wheel: \ + {os.path.basename(cached_wheel)}.", + flush=True, + ) + print("--> Installing from cache, skipping all source builds.", flush=True) + install_command = [sys.executable, "-m", "pip", "install", cached_wheel] + run_command(install_command) + install_nixl_shim_for_meta_compat(get_site_packages_dir()) + print("\n--- Installation from cache complete. ---", flush=True) + return + + print( + "\n--> No installed package or cached wheel found. 
\ + Starting full build process...", + flush=True, + ) + print("\n--> Installing auditwheel...", flush=True) + run_command([sys.executable, "-m", "pip", "install", "auditwheel"]) + install_system_dependencies() + ucx_install_path = os.path.abspath(UCX_INSTALL_DIR) + print(f"--> Using wheel cache directory: {WHEELS_CACHE_HOME}", flush=True) + os.makedirs(WHEELS_CACHE_HOME, exist_ok=True) + + # -- Step 1: Build UCX from source -- + print("\n[1/3] Configuring and building UCX from source...", flush=True) + if not os.path.exists(UCX_DIR): + run_command(["git", "clone", UCX_REPO_URL, UCX_DIR]) + ucx_source_path = os.path.abspath(UCX_DIR) + run_command(["git", "checkout", "v1.19.x"], cwd=ucx_source_path) + run_command(["./autogen.sh"], cwd=ucx_source_path) + configure_command = [ + "./configure", + f"--prefix={ucx_install_path}", + "--enable-shared", + "--disable-static", + "--disable-doxygen-doc", + "--enable-optimizations", + "--enable-cma", + "--enable-devel-headers", + "--with-verbs", + "--enable-mt", + "--with-ze=no", + ] + run_command(configure_command, cwd=ucx_source_path) + run_command(["make", "-j", str(os.cpu_count() or 1)], cwd=ucx_source_path) + run_command(["make", "install"], cwd=ucx_source_path) + print("--- UCX build and install complete ---", flush=True) + + # -- Step 2: Build NIXL wheel from source -- + print("\n[2/3] Building NIXL wheel from source...", flush=True) + if not os.path.exists(NIXL_DIR): + run_command(["git", "clone", NIXL_REPO_URL, NIXL_DIR]) + else: + run_command(["git", "fetch", "--tags"], cwd=NIXL_DIR) + run_command(["git", "checkout", NIXL_VERSION], cwd=NIXL_DIR) + print(f"--> Checked out NIXL version: {NIXL_VERSION}", flush=True) + + build_env = os.environ.copy() + build_env["PKG_CONFIG_PATH"] = os.path.join(ucx_install_path, "lib", "pkgconfig") + ucx_lib_path = os.path.join(ucx_install_path, "lib") + ucx_plugin_path = os.path.join(ucx_lib_path, "ucx") + existing_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + build_env["LD_LIBRARY_PATH"] = ( + f"{ucx_lib_path}:{ucx_plugin_path}:{existing_ld_path}".strip(":") + ) + build_env["LDFLAGS"] = "-Wl,-rpath,$ORIGIN" + print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True) + + temp_wheel_dir = os.path.join(ROOT_DIR, "temp_wheelhouse") + run_command( + [ + sys.executable, + "-m", + "pip", + "wheel", + ".", + "--no-deps", + f"--wheel-dir={temp_wheel_dir}", + ], + cwd=os.path.abspath(NIXL_DIR), + env=build_env, + ) + + # -- Step 3: Repair the wheel by copying UCX libraries -- + print("\n[3/3] Repairing NIXL wheel to include UCX libraries...", flush=True) + unrepaired_wheel = find_nixl_wheel_in_cache(temp_wheel_dir) + if not unrepaired_wheel: + raise RuntimeError("Failed to find the NIXL wheel after building it.") + + # We tell auditwheel to ignore the plugin that mesonpy already handled. + auditwheel_command = [ + "auditwheel", + "repair", + "--exclude", + "libplugin_UCX.so", # <-- Exclude because mesonpy already includes it + unrepaired_wheel, + f"--wheel-dir={WHEELS_CACHE_HOME}", + ] + run_command(auditwheel_command, env=build_env) + + # --- CLEANUP --- + # No more temporary files to remove, just the temp wheelhouse + run_command(["rm", "-rf", temp_wheel_dir]) + # --- END CLEANUP --- + + newly_built_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME) + if not newly_built_wheel: + raise RuntimeError("Failed to find the repaired NIXL wheel.") + + print( + f"--> Successfully built self-contained wheel: \ + {os.path.basename(newly_built_wheel)}. 
Now installing...", + flush=True, + ) + install_command = [ + sys.executable, + "-m", + "pip", + "install", + "--no-deps", # w/o "no-deps", it will install cuda-torch + newly_built_wheel, + ] + if args.force_reinstall: + install_command.insert(-1, "--force-reinstall") + + run_command(install_command) + install_nixl_shim_for_meta_compat(get_site_packages_dir()) + print("--- NIXL installation complete ---", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Build and install UCX and NIXL dependencies." + ) + parser.add_argument( + "--force-reinstall", + action="store_true", + help="Force rebuild and reinstall of UCX and NIXL \ + even if they are already installed.", + ) + args = parser.parse_args() + build_and_install_prerequisites(args) From 51474c64e01ae866de086dd12fd5216aaad8c2e9 Mon Sep 17 00:00:00 2001 From: Jinseok Lee Date: Tue, 7 Apr 2026 16:28:23 +0900 Subject: [PATCH 3/3] formatter fixed Signed-off-by: Jinseok Lee --- tools/install_nixl_from_source_ubuntu.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py index c9eb307d0..3c27fda5b 100644 --- a/tools/install_nixl_from_source_ubuntu.py +++ b/tools/install_nixl_from_source_ubuntu.py @@ -101,16 +101,19 @@ def install_nixl_shim_for_meta_compat(site_packages_dir): nixl_dir = os.path.join(site_packages_dir, "nixl") os.makedirs(nixl_dir, exist_ok=True) init_py = os.path.join(nixl_dir, "__init__.py") - shim_content = '''# Shim so that source-installed nixl-cu12 is usable as "nixl" (meta-package style). + shim_content = """# Shim so that source-installed nixl-cu12 is usable as "nixl" (meta-package style). # See install_nixl_from_source_ubuntu.py import sys import nixl_cu12 sys.modules["nixl._api"] = nixl_cu12._api sys.modules["nixl._bindings"] = nixl_cu12._bindings -''' +""" with open(init_py, "w") as f: f.write(shim_content) - print("--> Installed 'nixl' shim for meta-package-style import (import nixl, from nixl._api import ...)", flush=True) + print( + "--> Installed 'nixl' shim for meta-package-style import (import nixl, from nixl._api import ...)", + flush=True, + ) def install_system_dependencies():
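
---

Usage sketch (untested; JSON fields follow upstream vLLM's kv_transfer_config,
model and ports are placeholders):

  # prefill engine (KV producer)
  vllm serve $MODEL --port 8100 \
    --kv-transfer-config '{"kv_connector": "RblnNixlConnector", "kv_role": "kv_producer", "kv_buffer_device": "cpu"}'

  # decode engine (KV consumer)
  vllm serve $MODEL --port 8200 \
    --kv-transfer-config '{"kv_connector": "RblnNixlConnector", "kv_role": "kv_consumer", "kv_buffer_device": "cpu"}'

  # proxy routing /v1/completions and /v1/chat/completions across both
  python examples/experimental/pd_disaggregation/toy_proxy_server.py \
    --port 8000 --prefiller-port 8100 --decoder-port 8200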