-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathall_reduce_bucketer.py
More file actions
182 lines (153 loc) · 7.11 KB
/
all_reduce_bucketer.py
File metadata and controls
182 lines (153 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import os
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
# The flag is True unless the environment variable ENABLE_NCCL_BASE_COLLECTIVES is
# set to exactly "0" (an unset variable defaults to "1", i.e. enabled).
# NOTE(review): nothing else in this module reads this flag — confirm whether
# external code still imports it before removing.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
class Bucket:
    """A flat, reusable buffer that batches several small tensors into one all-reduce.

    Callers copy values into ``data[offset:...]``, advance ``offset`` and register
    zero-argument callbacks. ``flush`` all-reduces the filled prefix
    ``data[:offset]``, runs the callbacks in registration order, then resets the
    bucket so the same buffer can be reused.
    """

    def __init__(self, data: Tensor, group: ProcessGroup):
        self.data = data  # flat buffer; only data[:offset] holds queued values
        self.group = group  # process group passed to dist.all_reduce
        self.offset = 0  # number of leading elements of `data` currently in use
        self.callbacks: List[Callable] = []  # invoked (no args) after each reduction

    def flush(self) -> None:
        """All-reduce the filled part of the bucket, run its callbacks, then reset it.

        A no-op when the bucket is empty. Callbacks are invoked after the
        ``dist.all_reduce`` call returns. Once they have run, the filled region is
        zeroed so the buffer can be reused, so any views into it handed out earlier
        only hold the reduced values while the callbacks execute.
        """
        if self.offset == 0:
            # An empty bucket must not have pending callbacks.
            assert len(self.callbacks) == 0
            return
        filled = self.data[: self.offset]
        dist.all_reduce(filled, group=self.group)
        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
        # Reuse the same buffer for the next batch: clear the used region and
        # forget the callbacks. No new memory is allocated here.
        filled.zero_()
        self.offset = 0
        self.callbacks.clear()

    def setup(self) -> None:
        """(Re)allocate the buffer's storage if ``teardown`` freed it.

        Pairing ``setup`` and ``teardown`` around the phase that communicates
        (e.g. the backward pass) means the buffer holds no memory the rest of the
        time, leaving more for other parts of training such as forward-pass
        activations. The storage contents after re-allocation are not
        initialized; callers overwrite them via ``copy_`` before reducing.
        """
        storage = self.data.storage()
        if storage.size() == 0:
            storage.resize_(self.data.numel())

    def teardown(self) -> None:
        """Free the buffer's storage; the bucket must be empty (already flushed)."""
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        # Shrinking the storage to 0 releases the memory while keeping the tensor
        # object (and its shape) so `setup` can re-allocate it later.
        self.data.storage().resize_(0)
class AllReduceBucketer:
    """
    Helper for bucketing multiple all-reduce operations on small tensors
    into larger all-reduce ops to improve communication efficiency.

    Tensors sharing a (dtype, device, process group) key are copied into a
    shared flat bucket and reduced together when the bucket runs out of room
    or when ``flush()`` is called. Tensors larger than the bucket are
    all-reduced immediately, in place.

    Usage::

        bucketer = AllReduceBucketer()
        bucketer.all_reduce_async(
            small_tensor, group, callback_fn=lambda result: print("small")
        )
        bucketer.all_reduce_async(
            big_tensor, group, callback_fn=lambda result: print("big")
        )
        bucketer.all_reduce_async(
            more_small_tensor, group, callback_fn=lambda result: print("small2")
        )
        bucketer.flush()  # callbacks only guaranteed to be called after flush()

        # Example output (note that it is out of order, due to bucketing):
        # big
        # small
        # small2

    Args:
        bucket_cap_mb (int, Optional): bucket size in MiB (1024 * 1024 bytes) for
            communicating. Values <= 0 disable bucketing, so every non-empty
            tensor is all-reduced immediately.
    """

    def __init__(self, bucket_cap_mb: int = 25):
        self.bucket_cap_mb = bucket_cap_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def all_reduce_async(
        self, input_tensor: Tensor, group: ProcessGroup, callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        All-reduce a tensor asynchronously, so smaller reductions can be
        bucketed together. The given callback (``callback_fn``) will be called
        with the reduced result at some later time. Call ``flush()`` to force
        all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_tensor``.

        Args:
            input_tensor (Tensor): tensor to all-reduce.
            group (ProcessGroup): process group for reduction.
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result:

                * for empty or too-large inputs, ``input_tensor`` itself
                  (large inputs are reduced in place);
                * for bucketed inputs, a view into the bucket buffer shaped like
                  ``input_tensor``. ``input_tensor`` itself is NOT updated, and
                  the view's contents are zeroed and reused right after the
                  callbacks run, so copy out anything that must be kept.
        """
        numel = input_tensor.numel()
        if numel == 0:
            # Nothing to reduce. Queuing this would leave the bucket offset at 0,
            # so a later flush would see an "empty" bucket with a pending
            # callback (assertion failure, callback never run).
            if callback_fn is not None:
                callback_fn(input_tensor)
            return

        bucket_size = self._get_shard_size(input_tensor.element_size(), group.size())
        if numel > bucket_size:
            # Input is too big to fit in the bucket: all-reduce it directly, in place.
            dist.all_reduce(input_tensor, group=group)
            if callback_fn is not None:
                callback_fn(input_tensor)
            return

        bucket = self._get_bucket(input_tensor, group)
        if numel > bucket.data.size(0) - bucket.offset:
            # Not enough space remaining in bucket, flush it now.
            bucket.flush()

        # Copy the flattened input into the next free slice of the bucket.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        offset = bucket.offset
        bucket.data[offset : offset + numel].copy_(input_tensor.reshape(-1))
        bucket.offset += numel

        # The callback will be given the reduced result, viewed in the input's shape.
        if callback_fn is not None:
            result_view = bucket.data[offset : offset + numel].view_as(input_tensor)
            bucket.callbacks.append(functools.partial(callback_fn, result_view))

    @torch.no_grad()
    def flush(self) -> None:
        """All-reduce any partially filled buckets and run their callbacks."""
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def teardown(self) -> None:
        """Free buffers from all buckets."""
        for bucket in self.buckets.values():
            bucket.teardown()

    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        """Return the bucket capacity in elements of ``element_size`` bytes.

        ``num_shards`` is accepted for signature compatibility but does not
        affect the result: an all-reduce bucket is not split across ranks.
        Returns 0 when bucketing is disabled (``bucket_cap_mb <= 0``).
        """
        # Deliberately not memoized: ``functools.lru_cache`` on a method keys on
        # ``self`` (keeping every bucketer alive) and would return stale values if
        # ``bucket_cap_mb`` changed. The arithmetic is trivial anyway.
        if self.bucket_cap_mb <= 0:  # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_cap_mb * MB / element_size
        return int(bucket_size)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> "Bucket":
        """Return the (set-up) bucket for ``tensor``'s dtype/device and ``group``, creating it if needed."""
        # TODO (Min): the `group` used here in the key is the object hash, not the content
        # hash. That means if FSDP instances are initialized with different process groups,
        # even when the group members are in fact the same, we end up creating different
        # buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # bucket.data is a flat 1-D buffer holding `shard_size` elements.
            shard_size = self._get_shard_size(tensor.element_size(), group.size())
            data = tensor.new_zeros(shard_size)
            self.buckets[key] = Bucket(data, group)
        # Re-allocate storage in case a previous `teardown` freed it.
        self.buckets[key].setup()
        return self.buckets[key]