Uniform distribute_buffer_sizes for calculating buffer size #94
@@ -7,7 +7,9 @@
"""

import heapq
import math
import operator
from collections.abc import Callable, Iterator, Sequence
from functools import partial, reduce
from itertools import accumulate, chain, compress, pairwise
@@ -151,3 +153,69 @@ def __exit__(
        exc_tb: TracebackType | None,
    ) -> None:
        self._exit_method()


def distribute_buffer_sizes(
We require 100% test coverage, and this function is not tested if we move it here. https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/tests/shampoo_utils_test.py is the file where you could add your tests.

Then may I add some test module in

Yes, because

Thanks for your review. I fixed

The code looks good to me; could you fix the type errors listed in https://github.com/facebookresearch/optimizers/actions/runs/14051552340/job/39345812119?pr=94? You might want to rebase your repo first.

I tested GitHub Actions using But it failed because But I am not certain about this change because:

I see, thanks. This appears to be a compatibility issue with The recommended workflow is to use the

In my view, both tools are useful in the following cases:

So I would like to carefully suggest that this update would be helpful. But as you mentioned,

If you want, you can create a separate PR to ensure compatibility with

Thanks. I will open a PR after this. @tsunghsienlee could you approve and merge #96? I opened a rebased PR for this.
||
buffer_sizes: tuple[int, ...], | ||
group_size: int, | ||
) -> tuple[tuple[int, int], ...]: | ||
"""Distribute given buffer sizes across ranks in a group. | ||
|
||
Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that | ||
total buffer sizes of each rank are as even as possible. This is currently performed | ||
using a greedy algorithm. We do not currently consider computational cost | ||
or kernel launching overheads. | ||
|
||
Note: A better distribution strategy should try to minimize the delta of buffer sizes | ||
between the most and the least allocated groups. | ||
|
||
Args: | ||
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed. | ||
group_size (int): Number of groups to distribute across. | ||
|
||
Returns: | ||
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the | ||
buffer size for each block and its assigned rank. | ||
|
||
Example: | ||
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2 | ||
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)] | ||
""" | ||
ALIGNMENT_BYTES = ( | ||
64 # necessary for determining buffer size, possibly hardware-dependent | ||
) | ||
|
||
# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size. | ||
aligned_buffer_sizes = [ | ||
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES | ||
for buffer_size in buffer_sizes | ||
] | ||
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes) | ||
allocated_buffer_sizes = [(0, group_index) for group_index in range(group_size)] | ||
heapq.heapify(allocated_buffer_sizes) | ||
|
||
for index, aligned_buffer_size in sorted( | ||
enumerate(aligned_buffer_sizes), | ||
key=operator.itemgetter(1), | ||
reverse=True, | ||
): | ||
# Greedily find the group with the least allocated buffer size and its group index | ||
# in order to allocate buffers on that group. | ||
( | ||
min_allocated_buffer_size, | ||
min_allocated_buffer_size_group_index, | ||
) = heapq.heappop(allocated_buffer_sizes) | ||
|
||
heapq.heappush( | ||
allocated_buffer_sizes, | ||
( | ||
min_allocated_buffer_size + aligned_buffer_size, | ||
min_allocated_buffer_size_group_index, | ||
), | ||
) | ||
buffer_size_ranks[index] = ( | ||
aligned_buffer_size, | ||
min_allocated_buffer_size_group_index, | ||
) | ||
|
||
return tuple(buffer_size_ranks) |
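
For reference, here is a minimal sketch of the kind of unit test the review above asks for. The import path distributed_shampoo.utils.shampoo_utils is an assumption inferred from the test-file URL mentioned in the discussion, and the expected tuples simply mirror the docstring example (500 rounds up to 512, the largest buffer lands on rank 0, the rest pack onto rank 1).

```python
# Hypothetical test sketch for distribute_buffer_sizes; the import path below
# is an assumption based on the repository layout discussed in this review.
import unittest

from distributed_shampoo.utils.shampoo_utils import distribute_buffer_sizes


class DistributeBufferSizesTest(unittest.TestCase):
    def test_distribute_buffer_sizes_two_groups(self) -> None:
        # Mirrors the docstring example: 500 is rounded up to 512 (the next
        # multiple of ALIGNMENT_BYTES = 64); the greedy pass assigns the
        # largest buffer to rank 0 and the remaining buffers to rank 1.
        self.assertEqual(
            distribute_buffer_sizes(buffer_sizes=(128, 64, 500, 256), group_size=2),
            ((128, 1), (64, 1), (512, 0), (256, 1)),
        )

    def test_distribute_buffer_sizes_single_group(self) -> None:
        # With a single group, every aligned buffer is assigned to rank 0,
        # and sizes are still rounded up to a multiple of 64 bytes.
        self.assertEqual(
            distribute_buffer_sizes(buffer_sizes=(1, 64), group_size=1),
            ((64, 0), (64, 0)),
        )


if __name__ == "__main__":
    unittest.main()
```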