Skip to content

Commit 3a3db95

Browse files
author
yexin
committed
add distributed abstraction
1 parent c3badb4 commit 3a3db95

File tree

5 files changed

+625
-409
lines changed

5 files changed

+625
-409
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from .base import (
    Distributed,
    init_process_group,
    destroy_process_group,
    is_initialized,
    all_gather_object,
    all_reduce,
    broadcast,
    barrier,
    new_group,
)

# Public API of the distributed abstraction package: the backend-agnostic
# `Distributed` interface plus the module-level dispatch helpers from .base.
__all__ = [
    "Distributed",
    "init_process_group",
    "destroy_process_group",
    "is_initialized",
    "all_gather_object",
    "all_reduce",
    "broadcast",
    "barrier",
    "new_group",
]
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from abc import ABC, abstractmethod
2+
import io
3+
import pickle
4+
from datetime import timedelta
5+
from typing import Any, List
6+
import importlib
7+
8+
import torch
9+
from torch.distributed import ReduceOp
10+
11+
12+
class Distributed(ABC):
    """Abstract interface for a device-specific distributed backend.

    Concrete subclasses (e.g. an NCCL backend for CUDA or an HCCL backend
    for NPU — see the mapping in ``init_process_group`` below) implement the
    collective-communication primitives. Module-level functions in this file
    dispatch to a single global instance of a subclass.
    """

    @abstractmethod
    def init_process_group(
        self,
        host: str,
        port: int,
        rank: int,
        world_size: int,
        timeout: timedelta,
    ):
        """Join the process group rooted at ``host:port`` as ``rank`` of ``world_size``."""
        raise NotImplementedError

    @abstractmethod
    def destroy_process_group(
        self,
        group,
    ):
        """Tear down ``group`` (backend-specific group handle) and release its resources."""
        raise NotImplementedError

    @abstractmethod
    def is_initialized(self) -> bool:
        """Return True if the process group has been initialized."""
        raise NotImplementedError

    @abstractmethod
    def all_gather_object(
        self,
        object_list: list[Any],
        obj: Any,
        group,
    ):
        """Gather a picklable ``obj`` from every rank into ``object_list`` (in place)."""
        raise NotImplementedError

    @abstractmethod
    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: ReduceOp,  # fixed spacing: was `op :ReduceOp`
        group,
    ):
        """Reduce ``tensor`` across all ranks with ``op``, in place."""
        raise NotImplementedError

    @abstractmethod
    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
        group,
    ):
        """Broadcast ``tensor`` from rank ``src`` to every rank in ``group``."""
        raise NotImplementedError

    @abstractmethod
    def barrier(
        self,
        group,
    ):
        """Block until every rank in ``group`` reaches this point."""
        raise NotImplementedError

    @abstractmethod
    def new_group(
        self,
        ranks: list[int],
    ):
        """Create and return a new communication group containing ``ranks``."""
        raise NotImplementedError
75+
76+
77+
# specific device instance
78+
_BACKEND_INSTANCE = None
79+
80+
_pickler = pickle.Pickler
81+
_unpickler = pickle.Unpickler
82+
83+
84+
def _object_to_tensor(obj, device):
85+
f = io.BytesIO()
86+
_pickler(f).dump(obj)
87+
byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
88+
byte_tensor = torch.ByteTensor(byte_storage).to(device)
89+
local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
90+
return byte_tensor, local_size
91+
92+
93+
def _tensor_to_object(tensor, tensor_size):
94+
tensor = tensor.cpu()
95+
buf = tensor.numpy().tobytes()[:tensor_size]
96+
return _unpickler(io.BytesIO(buf)).load()
97+
98+
99+
def _flatten_for_scatter_gather(tensor_list, copy=False):
100+
if not tensor_list:
101+
raise RuntimeError("Received an empty list.")
102+
t = tensor_list[0]
103+
buffer_shape = [len(tensor_list)] + list(t.shape)
104+
105+
buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
106+
if copy:
107+
for i, tensor in enumerate(tensor_list):
108+
buffer[i].copy_(tensor)
109+
return buffer
110+
111+
112+
def _common_all_gather_object(comm, device, world_size, object_list, object):
    # Backend-agnostic all_gather_object: pickle `object`, exchange the pickled
    # payloads of all `world_size` ranks via two tensor all_gathers on `comm`,
    # and unpickle the results into `object_list` in place.
    # NOTE(review): `comm` is an opaque backend handle; this assumes
    # `comm.all_gather(output_tensor, input_tensor)` gathers into a flat,
    # pre-allocated output — confirm against the concrete backends.
    # NOTE(review): parameter `object` shadows the builtin of the same name.
    input_tensor, local_size = _object_to_tensor(object, device)
    # Round 1: gather every rank's payload size so we can size the output buffer.
    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
    comm.all_gather(object_sizes_tensor, local_size)
    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
    max_object_size = int(max(object_size_list).item())
    # Pad the local payload to the max size so all ranks contribute equal-sized
    # chunks (bytes past the real payload are garbage and are sliced off below).
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * world_size, dtype=torch.uint8, device=device
    )

    # Round 2: gather the padded payloads into one flat buffer.
    comm.all_gather(coalesced_output_tensor, input_tensor)
    # Split the flat buffer back into one max-sized chunk per rank.
    output_tensors = [
        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
        for i in range(world_size)
    ]
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        # One-element LongTensor; _tensor_to_object slices with it via __index__.
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
132+
133+
134+
def init_process_group(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    device_type: str,
    timeout: timedelta = timedelta(seconds=300),
):
    """Select the backend for ``device_type`` and initialize the process group.

    Instantiates the matching ``Distributed`` subclass, stores it in the
    module-global ``_BACKEND_INSTANCE``, and delegates the actual group
    initialization to it.

    Raises:
        ValueError: if ``device_type`` is not one of the supported types.
    """
    global _BACKEND_INSTANCE

    # Supported device types and the backend class (relative module path)
    # implementing the `Distributed` interface for each.
    mapping = {
        "cuda": ".nccl.DistributedNccl",
        "npu": ".hccl.DistributedHccl",
    }

    if device_type not in mapping:
        raise ValueError(f"Unsupported device type: {device_type}")

    module_path, class_name = mapping[device_type].rsplit(".", 1)
    # Anchor the relative import on this module's own package. The previous
    # anchor ".checkpoint_engine.distributed" carried a leading dot; importlib
    # resolves that to an invalid module name and the import fails.
    module = importlib.import_module(module_path, __package__)
    backend_class = getattr(module, class_name)

    _BACKEND_INSTANCE = backend_class()
    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
158+
159+
160+
def destroy_process_group(group=None):
    """Tear down ``group`` via the active backend; RuntimeError if uninitialized."""
    backend = _BACKEND_INSTANCE
    if backend is None:
        raise RuntimeError("distribute module not initialized")
    backend.destroy_process_group(group)
164+
165+
166+
def is_initialized() -> bool:
    """Return True if a backend is selected and its process group is initialized.

    Unlike the other module-level wrappers, this never raises: with no backend
    selected it simply reports False.
    """
    if _BACKEND_INSTANCE is None:
        return False
    # Bug fix: the backend's result was previously computed but not returned,
    # so this function always fell through to an implicit None.
    return _BACKEND_INSTANCE.is_initialized()
170+
171+
def all_gather_object(
    object_list: list[Any],
    obj: Any,
    group=None,
):
    """Gather a picklable ``obj`` from every rank into ``object_list`` (in place)."""
    backend = _BACKEND_INSTANCE
    if backend is None:
        raise RuntimeError("distribute module not initialized")
    backend.all_gather_object(object_list, obj, group)
179+
180+
181+
def all_reduce(
    tensor: torch.Tensor,
    op=ReduceOp.SUM,
    group=None,
):
    """Reduce ``tensor`` across all ranks with ``op``, in place."""
    backend = _BACKEND_INSTANCE
    if backend is None:
        raise RuntimeError("distribute module not initialized")
    backend.all_reduce(tensor, op, group)
189+
190+
191+
def broadcast(
    tensor: torch.Tensor,
    src=None,
    group=None,
):
    """Broadcast ``tensor`` from rank ``src`` to every rank in ``group``.

    Raises:
        RuntimeError: if ``init_process_group`` has not been called.
    """
    if _BACKEND_INSTANCE is None:
        raise RuntimeError("distribute module not initialized")
    # Bug fix: this wrapper dispatched to the backend's all_reduce (passing
    # `src` where a reduce op was expected) instead of broadcast.
    _BACKEND_INSTANCE.broadcast(tensor, src, group)
199+
200+
201+
def barrier(group=None):
    """Block until every rank in ``group`` reaches this point."""
    backend = _BACKEND_INSTANCE
    if backend is None:
        raise RuntimeError("distribute module not initialized")
    backend.barrier(group)
205+
206+
207+
def new_group(ranks: list[int]):
    """Create and return a new communication group containing ``ranks``.

    Raises:
        RuntimeError: if ``init_process_group`` has not been called.
    """
    if _BACKEND_INSTANCE is None:
        raise RuntimeError("distribute module not initialized")
    # Bug fix: the backend's group handle was previously created but
    # discarded, leaving callers with None and no way to use the group.
    return _BACKEND_INSTANCE.new_group(ranks)

0 commit comments

Comments
 (0)