diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py index c8f4613..054c64d 100644 --- a/distributed_shampoo/utils/shampoo_ddp_distributor.py +++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py @@ -7,9 +7,7 @@ """ -import heapq import logging -import operator from functools import partial from itertools import islice from typing import Any @@ -26,6 +24,7 @@ from distributed_shampoo.utils.shampoo_distributor import DistributorInterface from distributed_shampoo.utils.shampoo_utils import ( compress_list, + distribute_buffer_sizes, generate_pairwise_indices, get_dtype_size, ) @@ -99,11 +98,12 @@ def __init__( group_rank: int = dist.get_rank(group=self._dist_group) # Assign ranks to blocks with their respective buffer size. - buffer_size_ranks = self._distribute_buffer_sizes( + buffer_size_ranks = distribute_buffer_sizes( buffer_sizes=tuple( blocked_param.numel() * get_dtype_size(communication_dtype) for blocked_param in self._global_blocked_params - ) + ), + group_size=self._group_size, ) global_block_info_list = self._construct_global_block_info_list( @@ -195,74 +195,6 @@ def update_params( self._global_masked_dist_blocked_buffers, ) - def _distribute_buffer_sizes( - self, - buffer_sizes: tuple[int, ...], - ) -> tuple[tuple[int, int], ...]: - """Distribute given buffer sizes across ranks in a group. - - Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that - total buffer sizes of each rank are as even as possible. This is currently performed - using a greedy algorithm. We do not currently consider computational cost - or kernel launching overheads. - - Note: A better distribution strategy should try to minimize the delta of buffer sizes - between the most and the least allocated groups. - - Args: - buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed. 
- - Returns: - buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the - buffer size for each block and its assigned rank. - - Example: - Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2 - -> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)] - - """ - - ALIGNMENT_BYTES = ( - 64 # necessary for determining buffer size, possibly hardware-dependent - ) - - # Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size. - aligned_buffer_sizes = [ - (buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES - for buffer_size in buffer_sizes - ] - buffer_size_ranks = [(-1, -1)] * len(buffer_sizes) - allocated_buffer_sizes = [ - (0, group_index) for group_index in range(self._group_size) - ] - heapq.heapify(allocated_buffer_sizes) - - for index, aligned_buffer_size in sorted( - enumerate(aligned_buffer_sizes), - key=operator.itemgetter(1), - reverse=True, - ): - # Greedily find the group with the least allocated buffer size and its group index - # in order to allocate buffers on that group. - ( - min_allocated_buffer_size, - min_allocated_buffer_size_group_index, - ) = heapq.heappop(allocated_buffer_sizes) - - heapq.heappush( - allocated_buffer_sizes, - ( - min_allocated_buffer_size + aligned_buffer_size, - min_allocated_buffer_size_group_index, - ), - ) - buffer_size_ranks[index] = ( - aligned_buffer_size, - min_allocated_buffer_size_group_index, - ) - - return tuple(buffer_size_ranks) - @torch.no_grad() def _construct_global_block_info_list( self, group_source_ranks: tuple[int, ...] 
diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py index 5e4e010..71d82d8 100644 --- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py @@ -7,9 +7,7 @@ """ -import heapq import logging -import operator from functools import partial from itertools import islice from math import prod @@ -29,6 +27,7 @@ from distributed_shampoo.utils.shampoo_distributor import DistributorInterface from distributed_shampoo.utils.shampoo_utils import ( compress_list, + distribute_buffer_sizes, generate_pairwise_indices, get_dtype_size, merge_small_dims, @@ -193,11 +192,12 @@ def __init__( comms_group_rank: int = dist.get_rank(self._comms_dist_group) # Assign ranks to blocks with their respective buffer size. - buffer_size_ranks = self._distribute_buffer_sizes( + buffer_size_ranks = distribute_buffer_sizes( buffer_sizes=tuple( blocked_param.numel() * get_dtype_size(communication_dtype) for blocked_param in self._global_blocked_params - ) + ), + group_size=self._dist_group_size, ) global_block_info_list = self._construct_global_block_info_list( @@ -289,73 +289,6 @@ def update_params( self._global_masked_dist_blocked_buffers, ) - def _distribute_buffer_sizes( - self, - buffer_sizes: tuple[int, ...], - ) -> tuple[tuple[int, int], ...]: - """Distribute given buffer sizes across ranks in a group. - - Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that - total buffer sizes of each rank are as even as possible. This is currently performed - using a greedy algorithm. We do not currently consider computational cost - or kernel launching overheads. - - Note: A better distribution strategy should try to minimize the delta of buffer sizes - between the most and the least allocated groups. - - Args: - buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed. 
- - Returns: - buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the - buffer size for each block and its assigned rank. - - Example: - Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2 - -> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)] - - """ - ALIGNMENT_BYTES = ( - 64 # necessary for determining buffer size, possibly hardware-dependent - ) - - # Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size. - aligned_buffer_sizes = [ - (buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES - for buffer_size in buffer_sizes - ] - buffer_size_ranks = [(-1, -1)] * len(buffer_sizes) - allocated_buffer_sizes = [ - (0, group_index) for group_index in range(self._dist_group_size) - ] - heapq.heapify(allocated_buffer_sizes) - - for index, aligned_buffer_size in sorted( - enumerate(aligned_buffer_sizes), - key=operator.itemgetter(1), - reverse=True, - ): - # Greedily find the group with the least allocated buffer size and its group index - # in order to allocate buffers on that group. 
- ( - min_allocated_buffer_size, - min_allocated_buffer_size_group_index, - ) = heapq.heappop(allocated_buffer_sizes) - - heapq.heappush( - allocated_buffer_sizes, - ( - min_allocated_buffer_size + aligned_buffer_size, - min_allocated_buffer_size_group_index, - ), - ) - buffer_size_ranks[index] = ( - aligned_buffer_size, - min_allocated_buffer_size_group_index, - ) - - return tuple(buffer_size_ranks) - def _construct_composable_block_ids( self, param_index: int, diff --git a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py index 42f3422..1098f4f 100644 --- a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py +++ b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py @@ -7,9 +7,7 @@ """ -import heapq import logging -import operator from functools import partial from itertools import islice from typing import Any, Iterable @@ -25,6 +23,7 @@ from distributed_shampoo.utils.shampoo_distributor import DistributorInterface from distributed_shampoo.utils.shampoo_utils import ( compress_list, + distribute_buffer_sizes, generate_pairwise_indices, get_dtype_size, ) @@ -179,11 +178,12 @@ def __init__( comms_group_rank: int = dist.get_rank(self._comms_dist_group) # Assign ranks to blocks with their respective buffer size. - buffer_size_ranks = self._distribute_buffer_sizes( + buffer_size_ranks = distribute_buffer_sizes( buffer_sizes=tuple( blocked_param.numel() * get_dtype_size(communication_dtype) for blocked_param in self._global_blocked_params - ) + ), + group_size=self._dist_group_size, ) global_block_info_list = self._construct_global_block_info_list( @@ -293,73 +293,6 @@ def update_params( self._global_masked_dist_blocked_buffers, ) - def _distribute_buffer_sizes( - self, - buffer_sizes: tuple[int, ...], - ) -> tuple[tuple[int, int], ...]: - """Distribute given buffer sizes across ranks in a group. - - Buffer sizes will be rounded up for memory allocation. 
Buffers are distributed such that - total buffer sizes of each rank are as even as possible. This is currently performed - using a greedy algorithm. We do not currently consider computational cost - or kernel launching overheads. - - Note: A better distribution strategy should try to minimize the delta of buffer sizes - between the most and the least allocated groups. - - Args: - buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed. - - Returns: - buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the - buffer size for each block and its assigned rank. - - Example: - Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2 - -> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)] - - """ - ALIGNMENT_BYTES = ( - 64 # necessary for determining buffer size, possibly hardware-dependent - ) - - # Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size. - aligned_buffer_sizes = [ - (buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES - for buffer_size in buffer_sizes - ] - buffer_size_ranks = [(-1, -1)] * len(buffer_sizes) - allocated_buffer_sizes = [ - (0, group_index) for group_index in range(self._dist_group_size) - ] - heapq.heapify(allocated_buffer_sizes) - - for index, aligned_buffer_size in sorted( - enumerate(aligned_buffer_sizes), - key=operator.itemgetter(1), - reverse=True, - ): - # Greedily find the group with the least allocated buffer size and its group index - # in order to allocate buffers on that group. 
def distribute_buffer_sizes(
    buffer_sizes: tuple[int, ...],
    group_size: int,
) -> tuple[tuple[int, int], ...]:
    """Distribute given buffer sizes across ranks in a group.

    Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
    total buffer sizes of each rank are as even as possible. This is currently performed
    using a greedy algorithm: buffers are visited from largest to smallest (aligned) size and
    each one is assigned to the rank with the smallest total allocated so far, with ties
    broken in favor of the lowest rank index. We do not currently consider computational cost
    or kernel launching overheads.

    Note: A better distribution strategy should try to minimize the delta of buffer sizes
    between the most and the least allocated groups.

    Args:
        buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.
        group_size (int): Number of groups to distribute across. Must be at least 1.

    Returns:
        buffer_size_ranks (tuple[tuple[int, int], ...]): A tuple of (aligned buffer size,
            assigned rank) pairs, in the same order as the given buffer_sizes.

    Raises:
        ValueError: If group_size is smaller than 1.

    Example:
        Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
        -> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]
    """
    # Fail early with a descriptive error; otherwise an empty heap would surface as an
    # opaque IndexError as soon as the first buffer is assigned.
    if group_size < 1:
        raise ValueError(f"group_size must be at least 1; got {group_size}.")

    ALIGNMENT_BYTES = (
        64  # necessary for determining buffer size, possibly hardware-dependent
    )

    # Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
    aligned_buffer_sizes = [
        (buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
        for buffer_size in buffer_sizes
    ]
    buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)

    # Min-heap of (total allocated buffer size, group index); tuple ordering makes the
    # group index the tie-breaker between equally loaded groups.
    allocated_buffer_sizes = [(0, group_index) for group_index in range(group_size)]
    heapq.heapify(allocated_buffer_sizes)

    # Visit buffers from largest to smallest; the sort is stable, so equally sized buffers
    # keep their original relative order.
    for index, aligned_buffer_size in sorted(
        enumerate(aligned_buffer_sizes),
        key=operator.itemgetter(1),
        reverse=True,
    ):
        # Greedily find the group with the least allocated buffer size and its group index
        # in order to allocate buffers on that group.
        (
            min_allocated_buffer_size,
            min_allocated_buffer_size_group_index,
        ) = allocated_buffer_sizes[0]

        # Replace the heap root with the updated total in a single pop-and-push step.
        heapq.heapreplace(
            allocated_buffer_sizes,
            (
                min_allocated_buffer_size + aligned_buffer_size,
                min_allocated_buffer_size_group_index,
            ),
        )
        buffer_size_ranks[index] = (
            aligned_buffer_size,
            min_allocated_buffer_size_group_index,
        )

    return tuple(buffer_size_ranks)
class DistributeBufferSizesTest(unittest.TestCase):
    """Unit tests for distribute_buffer_sizes."""

    def test_distribute_buffer_sizes(self) -> None:
        """Check aligned sizes and rank assignments for several group configurations.

        Each scenario runs in its own subTest so that one failing scenario does not
        prevent the remaining scenarios from being checked and reported.
        """
        # Each entry: (description, buffer_sizes, group_size, expected_result).
        test_cases: tuple[
            tuple[str, tuple[int, ...], int, tuple[tuple[int, int], ...]], ...
        ] = (
            (
                "Even distribution of buffer sizes",
                (128, 64, 500, 256),
                2,
                ((128, 1), (64, 1), (512, 0), (256, 1)),
            ),
            (
                "Single group",
                (128, 64, 500, 256),
                1,
                ((128, 0), (64, 0), (512, 0), (256, 0)),
            ),
            (
                "More groups than buffers",
                (128, 64),
                4,
                ((128, 0), (64, 1)),
            ),
            (
                "Empty buffer sizes",
                (),
                2,
                (),
            ),
        )

        for description, buffer_sizes, group_size, expected_result in test_cases:
            with self.subTest(
                msg=description, buffer_sizes=buffer_sizes, group_size=group_size
            ):
                self.assertEqual(
                    distribute_buffer_sizes(buffer_sizes, group_size),
                    expected_result,
                )