Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/ray/multi_nodes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ TensorRT-LLM supports a prototype [Ray orchestrator](../README.md) as an alterna
2. Once on the head node, launch a multi-node Ray cluster:
```shell
# Remember to set CONTAINER and MOUNTS env vars or variables inside the script to your path.
# You can add the TensorRT-LLM installation command in this script if it is not preinstalled in your container.
>> bash -e run_cluster.sh
```

3. Enter the head container and run your TensorRT-LLM driver script

Note that this step requires TensorRT-LLM to be installed in the containers on all nodes. If it isn’t, install it manually inside each node’s container.

```shell
# On the head node
>> sacct
Expand All @@ -31,9 +35,7 @@ TensorRT-LLM supports a prototype [Ray orchestrator](../README.md) as an alterna
>> enroot list -f # get process id
>> enroot exec <process id> bash

# Under your work directory:
>> pip install -e . # if needed
# You can change this script to a model and parallel settings effective for multi-node inference (e.g., TP8 or TP4PP4)
# You can change this script to a model and parallel settings effective for multi-node inference (e.g., TP8 or TP4PP4).
>> python examples/ray/llm_inference_async_ray.py
```

Expand Down
1 change: 0 additions & 1 deletion examples/ray/multi_nodes/run_cluster.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ echo -e "${BLUE}[INFO] Logs : $LOG_DIR${RESET}"
########################################################

# enabled dashboard only for debug
# Add cd tekit/ and pip install -e . to head_cmd and worker_cmd if needed
# Add apt-get install -y --no-install-recommends libzmq3-dev for multi-node disagg

head_cmd=$(cat <<EOF
Expand Down
57 changes: 19 additions & 38 deletions tensorrt_llm/_torch/distributed/communicator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from functools import wraps
from typing import Optional, ValuesView
from typing import Optional

import numpy as np
import ray
Expand Down Expand Up @@ -167,6 +167,23 @@ def pp_broadcast(self, obj, root=0):
return self.pp_comm.bcast(obj, root)


class MultiHandleWrapper:
"""
Wrapper that encapsulates multiple handles and provides a single wait() interface
to unify the API between MPIDist and TorchDist.
"""

def __init__(self, handles):
self.handles = handles if isinstance(handles, list) else [handles]

def wait(self):
for handle in self.handles:
try:
handle.wait()
except Exception as e:
raise RuntimeError(f"Asynchronous operation failed: {e}") from e


class TorchDist(Distributed):

@property
Expand Down Expand Up @@ -364,7 +381,7 @@ def isend_object(self, obj, dest, tag=0):
dst=dest,
tag=tag))
works.append(torch.distributed.isend(input_tensor, dst=dest, tag=tag))
return works
return MultiHandleWrapper(works)

@log_op
def recv_object_from_isend(self, src, tag):
Expand All @@ -376,42 +393,6 @@ def recv_object_from_isend(self, src, tag):
return _tensor_to_object(recv_tensor, bytes_size,
torch.distributed.group.WORLD)

@log_op
def isend_tensor_list(self,
tensor_list: ValuesView[torch.Tensor],
dest,
tag=0):
if len(tensor_list) == 0:
return None
elif len(tensor_list) == 1:
return [self.isend_tensor(next(iter(tensor_list)), dest, tag)]
return [dist.isend(torch.cat(tensor_list), dst=dest, tag=tag)]

@log_op
def recv_tensor_list(self,
tensor_list: ValuesView[torch.Tensor],
src,
tag=0):
if len(tensor_list) == 0:
return []

first_tensor = next(iter(tensor_list))
if len(tensor_list) == 1:
return [self.recv_tensor(first_tensor, src, tag)]

# Receive tensors
recv_tensor = torch.empty_like(torch.cat(
[t.to('meta') for t in tensor_list]),
device=first_tensor.device)
dist.recv(recv_tensor, src, tag)
# Assign to tensor_list
recv_tensor = recv_tensor.flatten()
offset = 0
for t in tensor_list:
t.copy_(recv_tensor[offset:offset + t.numel()].reshape(t.shape))
offset += t.numel()
return tensor_list

@log_op
def allreduce(self,
obj: int | float | torch.Tensor,
Expand Down
15 changes: 4 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,27 +571,20 @@ def _broadcast_new_requests(

# Tag for communication
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
works = []

# Send payloads
if not self.dist.is_first_pp_rank:
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)

if not self.dist.is_last_pp_rank:
if self._disable_mpi:
works.extend(
self.dist.isend_object(payloads, self.dist.next_pp_rank,
tag))
isend_payload = self.dist.isend_object(payloads,
self.dist.next_pp_rank,
tag)
isend_payload.wait()
else:
self.dist.send_object(payloads, self.dist.next_pp_rank, tag)

for work in works:
try:
work.wait()
except Exception as e:
raise RuntimeError(
f"Asynchronous broadcast operation failed: {e}") from e

return payloads

def _attach_py_objects_to_requests(self, requests: List[RequestQueueItem],
Expand Down
10 changes: 1 addition & 9 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,15 +961,7 @@ def _executor_loop_pp(self):

def wait_on_pp_send_handles(self, microbatch_id):
if self.send_handles[microbatch_id] is not None:
if self._disable_mpi:
for work in self.send_handles[microbatch_id]:
try:
work.wait()
except Exception as e:
raise RuntimeError(
f"Asynchronous send operation failed: {e}") from e
else:
self.send_handles[microbatch_id].wait()
self.send_handles[microbatch_id].wait()
self.send_handles[microbatch_id] = None

def _prepare_and_schedule_batch(self):
Expand Down
Loading