-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathtorch_collective.py
More file actions
219 lines (183 loc) · 8.56 KB
/
torch_collective.py
File metadata and controls
219 lines (183 loc) · 8.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import datetime
import os
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from typing_extensions import Self, override
from lightning.fabric.plugins.collectives.collective import Collective
from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp
if not dist.is_available():
    # Without torch.distributed, fall back to a fixed timeout of 1800 s (30 minutes).
    default_pg_timeout = datetime.timedelta(seconds=1800)
else:
    # The import is guarded on availability; use torch's own default process-group timeout constant.
    from torch.distributed.constants import default_pg_timeout
class TorchCollective(Collective):
    """Collective operations using `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature which is still in development.

    """

    # Shared by all instances (always read/written through the class): set to True by ``setup`` after it initialized
    # the default process group, and reset to False by ``teardown`` once that group is gone.
    manages_default_group = False
    # Environment variables that ``setup`` temporarily exports from ``main_address`` / ``main_port``.
    addr_key = "MASTER_ADDR"
    port_key = "MASTER_PORT"

    def __init__(self) -> None:
        """Create the collective.

        Raises:
            RuntimeError: If ``torch.distributed`` is not available.

        """
        if not dist.is_available():
            raise RuntimeError("Torch distributed is not available.")
        super().__init__()

    @property
    @override
    def group(self) -> CollectibleGroup:
        """The process group used by every operation; falls back to ``dist.GroupMember.WORLD`` when unset."""
        if self._group is None:
            self._group = dist.GroupMember.WORLD
        return super().group

    @property
    @override
    def rank(self) -> int:
        """Rank of this process within ``self.group``."""
        return dist.get_rank(self.group)  # type: ignore[arg-type]

    @property
    @override
    def world_size(self) -> int:
        """Number of processes in ``self.group``."""
        return dist.get_world_size(self.group)  # type: ignore[arg-type]

    @override
    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
        """Broadcast ``tensor`` from rank ``src`` over ``self.group`` and return the same tensor object."""
        dist.broadcast(tensor, src, group=self.group)  # type: ignore[arg-type]
        return tensor

    @override
    def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
        """All-reduce ``tensor`` with ``op`` (an op name such as ``"sum"`` or a native op) and return it."""
        op = self._convert_to_native_op(op)
        dist.all_reduce(tensor, op=op, group=self.group)
        return tensor

    @override
    def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
        """Reduce ``tensor`` with ``op`` onto rank ``dst`` and return it."""
        op = self._convert_to_native_op(op)
        dist.reduce(tensor, dst, op=op, group=self.group)  # type: ignore[arg-type]
        return tensor

    @override
    def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]:
        """Gather ``tensor`` from every rank of the group into ``tensor_list`` and return that list."""
        dist.all_gather(tensor_list, tensor, group=self.group)
        return tensor_list

    @override
    def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]:
        """Gather ``tensor`` from every rank into ``gather_list`` on rank ``dst`` and return that list."""
        dist.gather(tensor, gather_list, dst, group=self.group)  # type: ignore[arg-type]
        return gather_list

    @override
    def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor:
        """Scatter ``scatter_list`` from rank ``src`` so each rank receives its entry in ``tensor``; return it."""
        dist.scatter(tensor, scatter_list, src, group=self.group)  # type: ignore[arg-type]
        return tensor

    @override
    def reduce_scatter(
        self, output: Tensor, input_list: list[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
    ) -> Tensor:
        """Reduce ``input_list`` across ranks with ``op``, scatter the result into ``output`` and return it."""
        op = self._convert_to_native_op(op)
        dist.reduce_scatter(output, input_list, op=op, group=self.group)
        return output

    @override
    def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]:
        """Exchange ``input_tensor_list`` entries between all ranks into ``output_tensor_list``; return it."""
        dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
        return output_tensor_list

    @override
    def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
        """Send ``tensor`` point-to-point to rank ``dst`` using ``tag``."""
        dist.send(tensor, dst, tag=tag, group=self.group)  # type: ignore[arg-type]

    @override
    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
        """Receive into ``tensor`` from rank ``src`` (``None`` is passed through to torch) and return it."""
        dist.recv(tensor, src, tag=tag, group=self.group)  # type: ignore[arg-type]
        return tensor

    def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]:
        """Gather ``obj`` from every rank into ``object_list`` and return that list."""
        dist.all_gather_object(object_list, obj, group=self.group)
        return object_list

    def broadcast_object_list(
        self, object_list: list[Any], src: int, device: Optional[torch.device] = None
    ) -> list[Any]:
        """Broadcast ``object_list`` from rank ``src`` and return it; ``device`` is forwarded to torch."""
        dist.broadcast_object_list(object_list, src, group=self.group, device=device)  # type: ignore[arg-type]
        return object_list

    def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]:
        """Gather ``obj`` from every rank into ``object_gather_list`` on rank ``dst`` and return that list."""
        dist.gather_object(obj, object_gather_list, dst, group=self.group)  # type: ignore[arg-type]
        return object_gather_list

    def scatter_object_list(
        self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0
    ) -> list[Any]:
        """Scatter ``scatter_object_input_list`` from rank ``src`` into ``scatter_object_output_list``; return it."""
        dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)  # type: ignore[arg-type]
        return scatter_object_output_list

    @override
    def barrier(self, device_ids: Optional[list[int]] = None) -> None:
        """Synchronize all ranks of ``self.group``; does nothing on processes that are not members of it."""
        if self.group == dist.GroupMember.NON_GROUP_MEMBER:
            return
        dist.barrier(group=self.group, device_ids=device_ids)  # type: ignore[arg-type]

    def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
        """Run ``dist.monitored_barrier`` on ``self.group`` with the given ``timeout`` and ``wait_all_ranks``."""
        dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)  # type: ignore[arg-type]

    @override
    def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
        """Initialize the default process group unless it is already initialized.

        Args:
            main_address: Exported as ``MASTER_ADDR`` while the group is initialized, unless that variable is already
                set in the environment.
            main_port: Exported (converted with ``str``) as ``MASTER_PORT`` under the same condition.
            **kwargs: Forwarded to ``super().setup``, which runs ``init_group``.

        Returns:
            This collective.

        """
        if self.is_initialized():
            return self
        # names of the variables exported here, so that only those are removed again
        setting_env = []
        try:
            # maybe set addr
            if main_address is not None and self.addr_key not in os.environ:
                os.environ[self.addr_key] = main_address
                setting_env.append(self.addr_key)
            # maybe set port
            if main_port is not None and self.port_key not in os.environ:
                os.environ[self.port_key] = str(main_port)
                setting_env.append(self.port_key)
            # this will `init_group`
            super().setup(**kwargs)
            # set as a class attribute so any instance can know whether we initialized the default process group
            TorchCollective.manages_default_group = True
        finally:
            # cleanup runs even when initialization raises, so a failed attempt does not leave the temporary
            # variables behind in the process environment
            for kenv in setting_env:
                os.environ.pop(kenv, None)
        return self

    @override
    def teardown(self) -> Self:
        """Tear down this collective's group, plus the default group if this class manages it and it is the last.

        Returns:
            This collective.

        """
        # membership is evaluated before the collective's own group is torn down
        group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
        super().teardown()  # will destroy its own group
        # try to destroy the default group. this should only be done by a group member to avoid race conditions,
        # and only if the class is managing it
        if (
            group_member
            and TorchCollective.manages_default_group
            and (default_group := dist.GroupMember.WORLD) is not None  # not destroyed already
            and len(dist.distributed_c10d._pg_map) == 1  # only the default group is left
        ):
            self.destroy_group(default_group)
            TorchCollective.manages_default_group = False
        elif TorchCollective.manages_default_group and dist.GroupMember.WORLD is None:
            # the default group no longer exists, so this class no longer manages it
            TorchCollective.manages_default_group = False
        return self

    @classmethod
    @override
    def is_available(cls) -> bool:
        """Whether ``torch.distributed`` is available."""
        return dist.is_available()

    @classmethod
    @override
    def is_initialized(cls) -> bool:
        """Whether ``torch.distributed`` is available and ``dist.is_initialized()`` reports True."""
        return cls.is_available() and dist.is_initialized()

    @classmethod
    @override
    def init_group(cls, **kwargs: Any) -> None:
        """Call ``dist.init_process_group`` with ``kwargs``."""
        dist.init_process_group(**kwargs)

    @classmethod
    @override
    def new_group(cls, **kwargs: Any) -> CollectibleGroup:
        """Create and return a process group via ``dist.new_group(**kwargs)``."""
        return dist.new_group(**kwargs)

    @classmethod
    @override
    def destroy_group(cls, group: CollectibleGroup) -> None:
        """Destroy ``group`` if it is registered in torch's ``_pg_map``; otherwise do nothing."""
        # can be called by all processes in the default group, group will be `object()` if they are not part of the
        # current group
        if group in dist.distributed_c10d._pg_map:
            dist.destroy_process_group(group)

    @classmethod
    @override
    def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
        """Convert ``op`` into a native torch reduce operation.

        Args:
            op: A ``ReduceOp``/``RedOpType`` (returned unchanged) or a case-insensitive ``ReduceOp`` member name.

        Returns:
            The matching native op.

        Raises:
            ValueError: If ``op`` is neither a native op nor a string, or names no ``ReduceOp`` member.

        """
        # `ReduceOp` is an empty shell for `RedOpType`, the latter being the actually returned class.
        # For example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where
        # `ReduceOp` is still the desired class, but it's created via a special `_make_nccl_premul_sum` function
        if isinstance(op, (ReduceOp, RedOpType)):
            return op
        if not isinstance(op, str):
            raise ValueError(f"Unsupported op {op!r} of type {type(op).__name__}")
        op = op.upper()
        # `ReduceOp` should contain `RedOpType`'s members
        value = getattr(ReduceOp, op, None)
        if value is None:
            raise ValueError(f"op {op!r} is not a member of `ReduceOp`")
        return value