Skip to content

Commit a09c04b

Browse files
namgyu-younfacebook-github-bot
authored andcommitted
Update common utility function : distribute_buffer_sizes (facebookresearch#96)
Summary: This is the rebased PR for facebookresearch#94 (already approved) - close : facebookresearch#93 Pull Request resolved: facebookresearch#96 Reviewed By: anana10c Differential Revision: D71830614 Pulled By: tsunghsienlee fbshipit-source-id: 4637536d526a69e5e4631f6493658cc979be3ae4
1 parent 8e25111 commit a09c04b

5 files changed

+128
-214
lines changed

distributed_shampoo/utils/shampoo_ddp_distributor.py

+4-72
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
88
"""
99

10-
import heapq
1110
import logging
12-
import operator
1311
from functools import partial
1412
from itertools import islice
1513
from typing import Any
@@ -26,6 +24,7 @@
2624
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
2725
from distributed_shampoo.utils.shampoo_utils import (
2826
compress_list,
27+
distribute_buffer_sizes,
2928
generate_pairwise_indices,
3029
get_dtype_size,
3130
)
@@ -99,11 +98,12 @@ def __init__(
9998
group_rank: int = dist.get_rank(group=self._dist_group)
10099

101100
# Assign ranks to blocks with their respective buffer size.
102-
buffer_size_ranks = self._distribute_buffer_sizes(
101+
buffer_size_ranks = distribute_buffer_sizes(
103102
buffer_sizes=tuple(
104103
blocked_param.numel() * get_dtype_size(communication_dtype)
105104
for blocked_param in self._global_blocked_params
106-
)
105+
),
106+
group_size=self._group_size,
107107
)
108108

109109
global_block_info_list = self._construct_global_block_info_list(
@@ -195,74 +195,6 @@ def update_params(
195195
self._global_masked_dist_blocked_buffers,
196196
)
197197

198-
def _distribute_buffer_sizes(
199-
self,
200-
buffer_sizes: tuple[int, ...],
201-
) -> tuple[tuple[int, int], ...]:
202-
"""Distribute given buffer sizes across ranks in a group.
203-
204-
Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
205-
total buffer sizes of each rank are as even as possible. This is currently performed
206-
using a greedy algorithm. We do not currently consider computational cost
207-
or kernel launching overheads.
208-
209-
Note: A better distribution strategy should try to minimize the delta of buffer sizes
210-
between the most and the least allocated groups.
211-
212-
Args:
213-
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.
214-
215-
Returns:
216-
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
217-
buffer size for each block and its assigned rank.
218-
219-
Example:
220-
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
221-
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]
222-
223-
"""
224-
225-
ALIGNMENT_BYTES = (
226-
64 # necessary for determining buffer size, possibly hardware-dependent
227-
)
228-
229-
# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
230-
aligned_buffer_sizes = [
231-
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
232-
for buffer_size in buffer_sizes
233-
]
234-
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
235-
allocated_buffer_sizes = [
236-
(0, group_index) for group_index in range(self._group_size)
237-
]
238-
heapq.heapify(allocated_buffer_sizes)
239-
240-
for index, aligned_buffer_size in sorted(
241-
enumerate(aligned_buffer_sizes),
242-
key=operator.itemgetter(1),
243-
reverse=True,
244-
):
245-
# Greedily find the group with the least allocated buffer size and its group index
246-
# in order to allocate buffers on that group.
247-
(
248-
min_allocated_buffer_size,
249-
min_allocated_buffer_size_group_index,
250-
) = heapq.heappop(allocated_buffer_sizes)
251-
252-
heapq.heappush(
253-
allocated_buffer_sizes,
254-
(
255-
min_allocated_buffer_size + aligned_buffer_size,
256-
min_allocated_buffer_size_group_index,
257-
),
258-
)
259-
buffer_size_ranks[index] = (
260-
aligned_buffer_size,
261-
min_allocated_buffer_size_group_index,
262-
)
263-
264-
return tuple(buffer_size_ranks)
265-
266198
@torch.no_grad()
267199
def _construct_global_block_info_list(
268200
self, group_source_ranks: tuple[int, ...]

distributed_shampoo/utils/shampoo_hsdp_distributor.py

+4-71
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
88
"""
99

10-
import heapq
1110
import logging
12-
import operator
1311
from functools import partial
1412
from itertools import islice
1513
from math import prod
@@ -29,6 +27,7 @@
2927
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
3028
from distributed_shampoo.utils.shampoo_utils import (
3129
compress_list,
30+
distribute_buffer_sizes,
3231
generate_pairwise_indices,
3332
get_dtype_size,
3433
merge_small_dims,
@@ -193,11 +192,12 @@ def __init__(
193192
comms_group_rank: int = dist.get_rank(self._comms_dist_group)
194193

195194
# Assign ranks to blocks with their respective buffer size.
196-
buffer_size_ranks = self._distribute_buffer_sizes(
195+
buffer_size_ranks = distribute_buffer_sizes(
197196
buffer_sizes=tuple(
198197
blocked_param.numel() * get_dtype_size(communication_dtype)
199198
for blocked_param in self._global_blocked_params
200-
)
199+
),
200+
group_size=self._dist_group_size,
201201
)
202202

203203
global_block_info_list = self._construct_global_block_info_list(
@@ -289,73 +289,6 @@ def update_params(
289289
self._global_masked_dist_blocked_buffers,
290290
)
291291

292-
def _distribute_buffer_sizes(
293-
self,
294-
buffer_sizes: tuple[int, ...],
295-
) -> tuple[tuple[int, int], ...]:
296-
"""Distribute given buffer sizes across ranks in a group.
297-
298-
Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
299-
total buffer sizes of each rank are as even as possible. This is currently performed
300-
using a greedy algorithm. We do not currently consider computational cost
301-
or kernel launching overheads.
302-
303-
Note: A better distribution strategy should try to minimize the delta of buffer sizes
304-
between the most and the least allocated groups.
305-
306-
Args:
307-
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.
308-
309-
Returns:
310-
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
311-
buffer size for each block and its assigned rank.
312-
313-
Example:
314-
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
315-
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]
316-
317-
"""
318-
ALIGNMENT_BYTES = (
319-
64 # necessary for determining buffer size, possibly hardware-dependent
320-
)
321-
322-
# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
323-
aligned_buffer_sizes = [
324-
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
325-
for buffer_size in buffer_sizes
326-
]
327-
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
328-
allocated_buffer_sizes = [
329-
(0, group_index) for group_index in range(self._dist_group_size)
330-
]
331-
heapq.heapify(allocated_buffer_sizes)
332-
333-
for index, aligned_buffer_size in sorted(
334-
enumerate(aligned_buffer_sizes),
335-
key=operator.itemgetter(1),
336-
reverse=True,
337-
):
338-
# Greedily find the group with the least allocated buffer size and its group index
339-
# in order to allocate buffers on that group.
340-
(
341-
min_allocated_buffer_size,
342-
min_allocated_buffer_size_group_index,
343-
) = heapq.heappop(allocated_buffer_sizes)
344-
345-
heapq.heappush(
346-
allocated_buffer_sizes,
347-
(
348-
min_allocated_buffer_size + aligned_buffer_size,
349-
min_allocated_buffer_size_group_index,
350-
),
351-
)
352-
buffer_size_ranks[index] = (
353-
aligned_buffer_size,
354-
min_allocated_buffer_size_group_index,
355-
)
356-
357-
return tuple(buffer_size_ranks)
358-
359292
def _construct_composable_block_ids(
360293
self,
361294
param_index: int,

distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py

+4-71
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
88
"""
99

10-
import heapq
1110
import logging
12-
import operator
1311
from functools import partial
1412
from itertools import islice
1513
from typing import Any, Iterable
@@ -25,6 +23,7 @@
2523
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
2624
from distributed_shampoo.utils.shampoo_utils import (
2725
compress_list,
26+
distribute_buffer_sizes,
2827
generate_pairwise_indices,
2928
get_dtype_size,
3029
)
@@ -179,11 +178,12 @@ def __init__(
179178
comms_group_rank: int = dist.get_rank(self._comms_dist_group)
180179

181180
# Assign ranks to blocks with their respective buffer size.
182-
buffer_size_ranks = self._distribute_buffer_sizes(
181+
buffer_size_ranks = distribute_buffer_sizes(
183182
buffer_sizes=tuple(
184183
blocked_param.numel() * get_dtype_size(communication_dtype)
185184
for blocked_param in self._global_blocked_params
186-
)
185+
),
186+
group_size=self._dist_group_size,
187187
)
188188

189189
global_block_info_list = self._construct_global_block_info_list(
@@ -293,73 +293,6 @@ def update_params(
293293
self._global_masked_dist_blocked_buffers,
294294
)
295295

296-
def _distribute_buffer_sizes(
297-
self,
298-
buffer_sizes: tuple[int, ...],
299-
) -> tuple[tuple[int, int], ...]:
300-
"""Distribute given buffer sizes across ranks in a group.
301-
302-
Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
303-
total buffer sizes of each rank are as even as possible. This is currently performed
304-
using a greedy algorithm. We do not currently consider computational cost
305-
or kernel launching overheads.
306-
307-
Note: A better distribution strategy should try to minimize the delta of buffer sizes
308-
between the most and the least allocated groups.
309-
310-
Args:
311-
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.
312-
313-
Returns:
314-
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
315-
buffer size for each block and its assigned rank.
316-
317-
Example:
318-
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
319-
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]
320-
321-
"""
322-
ALIGNMENT_BYTES = (
323-
64 # necessary for determining buffer size, possibly hardware-dependent
324-
)
325-
326-
# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
327-
aligned_buffer_sizes = [
328-
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
329-
for buffer_size in buffer_sizes
330-
]
331-
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
332-
allocated_buffer_sizes = [
333-
(0, group_index) for group_index in range(self._dist_group_size)
334-
]
335-
heapq.heapify(allocated_buffer_sizes)
336-
337-
for index, aligned_buffer_size in sorted(
338-
enumerate(aligned_buffer_sizes),
339-
key=operator.itemgetter(1),
340-
reverse=True,
341-
):
342-
# Greedily find the group with the least allocated buffer size and its group index
343-
# in order to allocate buffers on that group.
344-
(
345-
min_allocated_buffer_size,
346-
min_allocated_buffer_size_group_index,
347-
) = heapq.heappop(allocated_buffer_sizes)
348-
349-
heapq.heappush(
350-
allocated_buffer_sizes,
351-
(
352-
min_allocated_buffer_size + aligned_buffer_size,
353-
min_allocated_buffer_size_group_index,
354-
),
355-
)
356-
buffer_size_ranks[index] = (
357-
aligned_buffer_size,
358-
min_allocated_buffer_size_group_index,
359-
)
360-
361-
return tuple(buffer_size_ranks)
362-
363296
def _construct_composable_block_ids(
364297
self,
365298
param_index: int,

0 commit comments

Comments
 (0)