Skip to content

Commit d2fce90

Browse files
author
kip-cxj
committed
Add StatelessProcessGroup support to extend the collective library
1 parent 009082d commit d2fce90

File tree

6 files changed

+876
-7
lines changed

6 files changed

+876
-7
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from .base import (
2+
Distributed,
3+
DistributedProcessGroup,
4+
all_gather_object,
5+
all_reduce,
6+
barrier,
7+
broadcast,
8+
destroy_process_group,
9+
init_process_group,
10+
is_initialized,
11+
new_group,
12+
use_backend,
13+
)
14+
15+
16+
# Explicit public API: exactly the names re-exported from .base above.
__all__ = [
    "Distributed",
    "DistributedProcessGroup",
    "all_gather_object",
    "all_reduce",
    "barrier",
    "broadcast",
    "destroy_process_group",
    "init_process_group",
    "is_initialized",
    "new_group",
    "use_backend",
]
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
import importlib
2+
import io
3+
import pickle
4+
from abc import ABC, abstractmethod
5+
from datetime import timedelta
6+
from typing import Any, Protocol
7+
8+
import torch
9+
import torch.distributed as torch_dist
10+
11+
12+
class CommunicatorProtocol(Protocol):
    """Structural (duck-typed) interface for a communicator with ``all_gather``.

    Only the call shape is specified here; the expected gather semantics
    (output tensor first, then the local input tensor) can be seen in how
    _common_all_gather_object invokes it.
    """

    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
14+
15+
16+
class CommGroup:
    """Immutable pairing of a raw communicator handle with its member ranks.

    Used as a lightweight stand-in for a torch.distributed ProcessGroup when
    a backend only exposes an integer communicator handle.
    """

    def __init__(self, comm_handle: int, ranks: list[int]):
        self._handle = comm_handle
        self._rank_list = ranks

    @property
    def handle(self) -> int:
        """The raw communicator handle passed at construction."""
        return self._handle

    @property
    def ranks(self) -> list[int]:
        """The ranks that belong to this communicator group."""
        return self._rank_list
28+
29+
30+
# A process group is either a native torch.distributed ProcessGroup or a raw
# communicator wrapped in CommGroup (presumably produced by the custom
# backends loaded via use_backend() — confirm against those modules).
DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
31+
32+
33+
class Distributed(ABC):
    """Abstract interface for a collective-communication backend.

    Concrete implementations (TorchBackend below, or the custom backends
    loaded via use_backend()) provide process-group management plus the
    collective primitives used by the module-level wrapper functions.
    """

    @abstractmethod
    def init_process_group(
        self,
        rank: int,
        world_size: int,
        store: torch_dist.TCPStore,
        **kwargs,
    ):
        """Initialize the default process group for this backend.

        Args:
            rank: this process's rank in ``[0, world_size)``.
            world_size: total number of participating processes.
            store: TCPStore used for rendezvous.
            **kwargs: backend-specific options (e.g. ``backend``, ``timeout``).
        """
        raise NotImplementedError

    @abstractmethod
    def destroy_process_group(
        self,
        group: DistributedProcessGroup | None = None,
    ):
        """Tear down *group*, or the default process group when None."""
        raise NotImplementedError

    @abstractmethod
    def is_initialized(self) -> bool:
        """Return True once init_process_group has completed."""
        raise NotImplementedError

    @abstractmethod
    def all_gather_object(
        self,
        object_list: list[Any],
        obj: Any,
        group: DistributedProcessGroup | None = None,
    ):
        """Gather picklable *obj* from every rank into *object_list*."""
        raise NotImplementedError

    @abstractmethod
    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: torch_dist.ReduceOp.RedOpType,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """Reduce *tensor* across ranks using *op*."""
        raise NotImplementedError

    @abstractmethod
    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """Broadcast *tensor* from rank *src* to all other ranks."""
        raise NotImplementedError

    @abstractmethod
    def barrier(
        self,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """Block until all ranks in *group* reach this call."""
        raise NotImplementedError

    @abstractmethod
    def new_group(
        self,
        ranks: list[int],
        **kwargs,
    ):
        """Create and return a new group containing *ranks*."""
        raise NotImplementedError
99+
100+
101+
class TorchBackend(Distributed):
    """Distributed implementation that delegates to ``torch.distributed``."""

    def init_process_group(
        self,
        rank: int,
        world_size: int,
        store: torch_dist.TCPStore,
        **kwargs,
    ):
        """Initialize torch.distributed.

        Recognized kwargs: ``backend`` (default ``"nccl"``) and ``timeout``
        (default 10 minutes); any other keyword arguments are ignored.
        """
        chosen_backend = kwargs.get("backend", "nccl")
        chosen_timeout = kwargs.get("timeout", timedelta(minutes=10))

        torch_dist.init_process_group(
            backend=chosen_backend,
            rank=rank,
            world_size=world_size,
            store=store,
            timeout=chosen_timeout,
        )

    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
        """Destroy *group*, or the default process group when None."""
        torch_dist.destroy_process_group(group)

    def is_initialized(self) -> bool:
        """Return True once the torch.distributed default group is up."""
        return torch_dist.is_initialized()

    def all_gather_object(
        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
    ):
        """Gather picklable *obj* from all ranks into *object_list*."""
        torch_dist.all_gather_object(object_list, obj, group=group)

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """All-reduce *tensor* with *op* (default SUM)."""
        torch_dist.all_reduce(tensor, op=op, group=group, **kwargs)

    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int = 0,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """Broadcast *tensor* from rank *src* (default 0)."""
        torch_dist.broadcast(tensor, src=src, group=group, **kwargs)

    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
        """Synchronize all ranks in *group*."""
        torch_dist.barrier(group=group, **kwargs)

    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
        """Create a new torch.distributed group over *ranks*."""
        return torch_dist.new_group(ranks=ranks, **kwargs)
154+
155+
156+
# Module-level singleton holding the active Distributed implementation.
# Defaults to the torch.distributed wrapper; rebound by use_backend().
_BACKEND_INSTANCE: Distributed = TorchBackend()
158+
159+
_pickler = pickle.Pickler
160+
_unpickler = pickle.Unpickler
161+
162+
163+
def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
164+
f = io.BytesIO()
165+
_pickler(f).dump(obj)
166+
byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
167+
byte_tensor = torch.ByteTensor(byte_storage).to(device)
168+
local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
169+
return byte_tensor, local_size
170+
171+
172+
def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
173+
tensor = tensor.cpu()
174+
buf = tensor.numpy().tobytes()[:tensor_size]
175+
return _unpickler(io.BytesIO(buf)).load()
176+
177+
178+
def _flatten_for_scatter_gather(
179+
tensor_list: list[torch.Tensor], copy: bool = False
180+
) -> torch.Tensor:
181+
if not tensor_list:
182+
raise RuntimeError("Received an empty list.")
183+
t = tensor_list[0]
184+
buffer_shape = [len(tensor_list)] + list(t.shape)
185+
186+
buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
187+
if copy:
188+
for i, tensor in enumerate(tensor_list):
189+
buffer[i].copy_(tensor)
190+
return buffer
191+
192+
193+
def _common_all_gather_object(
    comm: CommunicatorProtocol,
    device: torch.device,
    world_size: int,
    object_list: list[Any],
    object: Any,
):
    """All-gather an arbitrary picklable *object* across *world_size* ranks.

    Mirrors torch.distributed.all_gather_object but runs on top of a raw
    communicator that only exposes ``all_gather``.  Results are written into
    *object_list* in rank order (it must have at least *world_size* entries).

    NOTE(review): assumes ``comm.all_gather(output, input)`` gathers every
    rank's *input* into *output* in rank order — confirm against the backend.
    """
    # Phase 1: serialize the local object and exchange every rank's payload
    # size, so all ranks agree on a common (maximum) chunk size.
    input_tensor, local_size = _object_to_tensor(object, device)
    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
    comm.all_gather(object_sizes_tensor, local_size)
    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
    max_object_size = int(max(object_size_list).item())
    # Pad the local payload to the max size so every rank contributes an
    # equally-sized chunk to the coalesced buffer.
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * world_size, dtype=torch.uint8, device=device
    )

    # Phase 2: gather the fixed-size chunks, then slice per rank.
    comm.all_gather(coalesced_output_tensor, input_tensor)
    output_tensors = [
        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
        for i in range(world_size)
    ]
    # Deserialize each rank's chunk, trimming the padding via its true size.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
219+
220+
221+
def use_backend(backend: str | None):
222+
global _BACKEND_INSTANCE
223+
224+
if not backend:
225+
return
226+
227+
mapping = {
228+
"vllm_nccl": ".nccl.DistributedNccl",
229+
"vllm_hccl": ".hccl.DistributedHccl",
230+
}
231+
if backend not in mapping:
232+
raise ValueError(f"Unsupported custom backend: {backend}")
233+
234+
module_path, class_name = mapping[backend].rsplit(".", 1)
235+
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
236+
backend_class = getattr(module, class_name)
237+
_BACKEND_INSTANCE = backend_class()
238+
239+
240+
def init_process_group(
    rank: int,
    world_size: int,
    store: torch_dist.TCPStore,
    **kwargs,
):
    """Initialize the process group on the active backend.

    Extra keyword arguments (e.g. ``backend``, ``timeout``) are forwarded
    unchanged to the backend implementation.
    """
    _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
247+
248+
249+
def destroy_process_group(group: DistributedProcessGroup | None = None):
    """Tear down *group* (or the default group when None) on the active backend."""
    _BACKEND_INSTANCE.destroy_process_group(group=group)
251+
252+
253+
def is_initialized() -> bool:
    """Return True if the active backend's process group is initialized."""
    return _BACKEND_INSTANCE.is_initialized()
255+
256+
257+
def all_gather_object(
    object_list: list[Any],
    obj: Any,
    group: DistributedProcessGroup | None = None,
):
    """Gather picklable *obj* from every rank into *object_list* via the active backend."""
    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group=group)
263+
264+
265+
def all_reduce(
    tensor: torch.Tensor,
    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    """All-reduce *tensor* across ranks with *op* (default SUM) via the active backend."""
    _BACKEND_INSTANCE.all_reduce(tensor, op=op, group=group, **kwargs)
272+
273+
274+
def broadcast(
    tensor: torch.Tensor,
    src: int = 0,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    """Broadcast *tensor* from rank *src* (default 0) via the active backend."""
    _BACKEND_INSTANCE.broadcast(tensor, src=src, group=group, **kwargs)
281+
282+
283+
def barrier(group: DistributedProcessGroup | None = None, **kwargs):
    """Block until all ranks in *group* reach this call, via the active backend."""
    _BACKEND_INSTANCE.barrier(group=group, **kwargs)
285+
286+
287+
def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
    """Create and return a new group over *ranks* via the active backend."""
    return _BACKEND_INSTANCE.new_group(ranks, **kwargs)

0 commit comments

Comments
 (0)