11import ctypes
22from datetime import timedelta
3- from typing import Any , List , Optional
3+ from typing import Any
44
55import torch
66from torch .distributed import ReduceOp
77from vllm .distributed .device_communicators .pynccl import PyNcclCommunicator
88from vllm .distributed .device_communicators .pynccl_wrapper import (
99 Function ,
1010 NCCLLibrary ,
11- buffer_type ,
1211 ncclComm_t ,
1312 ncclResult_t ,
1413)
1514from vllm .distributed .utils import StatelessProcessGroup
1615from vllm .utils import current_stream
16+
1717from checkpoint_engine .distributed .base import Distributed , _common_all_gather_object
1818
1919
@@ -132,7 +132,6 @@ def init_process_group(
132132 self .comm = self .pynccl .comm
133133 self .initialized = True
134134
135-
136135 def destroy_process_group (
137136 self ,
138137 group = None ,
@@ -155,12 +154,7 @@ def is_initialized(self) -> bool:
155154 return self .initialized
156155
157156
158- def all_gather_object (
159- self ,
160- object_list : list [Any ],
161- obj : Any ,
162- group = None
163- ):
157+ def all_gather_object (self , object_list : list [Any ], obj : Any , group = None ):
164158 assert self .initialized , "not initialized"
165159
166160 if group :
@@ -175,12 +169,7 @@ def all_gather_object(
175169 self .pynccl .comm = self .comm
176170
177171
178- def all_reduce (
179- self ,
180- tensor : torch .Tensor ,
181- op = ReduceOp .SUM ,
182- group = None
183- ):
172+ def all_reduce (self , tensor : torch .Tensor , op = ReduceOp .SUM , group = None ):
184173 assert self .initialized , "not initialized"
185174
186175 if group :
@@ -196,12 +185,7 @@ def all_reduce(
196185 self .pynccl .comm = self .comm
197186
198187
199- def broadcast (
200- self ,
201- tensor : torch .Tensor ,
202- src = None ,
203- group = None
204- ):
188+ def broadcast (self , tensor : torch .Tensor , src = None , group = None ):
205189 assert self .initialized , "not initialized"
206190
207191 if group :
@@ -221,10 +205,7 @@ def broadcast(
221205 self .pynccl .rank = self .rank
222206
223207
224- def barrier (
225- self ,
226- group = None
227- ):
208+ def barrier (self , group = None ):
228209 assert self .initialized , "not initialized"
229210
230211 if group :
@@ -240,10 +221,7 @@ def barrier(
240221 self .pynccl .comm = self .comm
241222
242223
243- def new_group (
244- self ,
245- ranks
246- ):
224+ def new_group (self , ranks ):
247225 assert self .initialized , "not initialized"
248226
249227 # ranks is None or []
0 commit comments