Update common utility function: distribute_buffer_sizes #96

Closed · wants to merge 4 commits
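This PR deduplicates the buffer-distribution logic: the three copies of `_distribute_buffer_sizes` in the DDP, HSDP, and hybrid-shard distributors (identical except for whether the group size comes from `self._group_size` or `self._dist_group_size`) are removed, and each distributor now imports a shared `distribute_buffer_sizes` helper from `distributed_shampoo.utils.shampoo_utils`. The new helper itself is not part of this diff; judging from the removed method bodies and the updated call sites, it presumably looks roughly like the sketch below, with the former attribute lookup replaced by an explicit `group_size` argument (the exact signature and docstring in `shampoo_utils.py` may differ).

```python
import heapq
import operator


def distribute_buffer_sizes(
    buffer_sizes: tuple[int, ...],
    group_size: int,
) -> tuple[tuple[int, int], ...]:
    """Distribute given buffer sizes across ranks in a group.

    Buffer sizes are rounded up to a multiple of ALIGNMENT_BYTES, then assigned
    greedily so that the total buffer size per rank is as even as possible.
    """
    ALIGNMENT_BYTES = 64  # necessary for determining buffer size, possibly hardware-dependent

    # Round each buffer size up to the smallest multiple of ALIGNMENT_BYTES >= buffer size.
    aligned_buffer_sizes = [
        (buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
        for buffer_size in buffer_sizes
    ]
    buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)

    # Min-heap of (allocated bytes, group index), so the least-loaded rank is always on top.
    allocated_buffer_sizes = [(0, group_index) for group_index in range(group_size)]
    heapq.heapify(allocated_buffer_sizes)

    # Assign the largest buffers first, each one to the currently least-loaded rank.
    for index, aligned_buffer_size in sorted(
        enumerate(aligned_buffer_sizes), key=operator.itemgetter(1), reverse=True
    ):
        min_allocated_buffer_size, min_group_index = heapq.heappop(allocated_buffer_sizes)
        heapq.heappush(
            allocated_buffer_sizes,
            (min_allocated_buffer_size + aligned_buffer_size, min_group_index),
        )
        buffer_size_ranks[index] = (aligned_buffer_size, min_group_index)

    return tuple(buffer_size_ranks)
```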
76 changes: 4 additions & 72 deletions distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -7,9 +7,7 @@

"""

import heapq
import logging
import operator
from functools import partial
from itertools import islice
from typing import Any
@@ -26,6 +24,7 @@
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
from distributed_shampoo.utils.shampoo_utils import (
compress_list,
distribute_buffer_sizes,
generate_pairwise_indices,
get_dtype_size,
)
@@ -99,11 +98,12 @@ def __init__(
group_rank: int = dist.get_rank(group=self._dist_group)

# Assign ranks to blocks with their respective buffer size.
buffer_size_ranks = self._distribute_buffer_sizes(
buffer_size_ranks = distribute_buffer_sizes(
buffer_sizes=tuple(
blocked_param.numel() * get_dtype_size(communication_dtype)
for blocked_param in self._global_blocked_params
)
),
group_size=self._group_size,
)

global_block_info_list = self._construct_global_block_info_list(
@@ -195,74 +195,6 @@ def update_params(
self._global_masked_dist_blocked_buffers,
)

def _distribute_buffer_sizes(
self,
buffer_sizes: tuple[int, ...],
) -> tuple[tuple[int, int], ...]:
"""Distribute given buffer sizes across ranks in a group.

Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
total buffer sizes of each rank are as even as possible. This is currently performed
using a greedy algorithm. We do not currently consider computational cost
or kernel launching overheads.

Note: A better distribution strategy should try to minimize the delta of buffer sizes
between the most and the least allocated groups.

Args:
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.

Returns:
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
buffer size for each block and its assigned rank.

Example:
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]

"""

ALIGNMENT_BYTES = (
64 # necessary for determining buffer size, possibly hardware-dependent
)

# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
aligned_buffer_sizes = [
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
for buffer_size in buffer_sizes
]
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
allocated_buffer_sizes = [
(0, group_index) for group_index in range(self._group_size)
]
heapq.heapify(allocated_buffer_sizes)

for index, aligned_buffer_size in sorted(
enumerate(aligned_buffer_sizes),
key=operator.itemgetter(1),
reverse=True,
):
# Greedily find the group with the least allocated buffer size and its group index
# in order to allocate buffers on that group.
(
min_allocated_buffer_size,
min_allocated_buffer_size_group_index,
) = heapq.heappop(allocated_buffer_sizes)

heapq.heappush(
allocated_buffer_sizes,
(
min_allocated_buffer_size + aligned_buffer_size,
min_allocated_buffer_size_group_index,
),
)
buffer_size_ranks[index] = (
aligned_buffer_size,
min_allocated_buffer_size_group_index,
)

return tuple(buffer_size_ranks)

@torch.no_grad()
def _construct_global_block_info_list(
self, group_source_ranks: tuple[int, ...]
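For reference, the docstring example carried over from the removed method can be checked directly against the shared helper (the import path follows the new `shampoo_utils` import above; the sketch at the top of this page behaves identically):

```python
from distributed_shampoo.utils.shampoo_utils import distribute_buffer_sizes

# With ALIGNMENT_BYTES = 64: 500 rounds up to 512, the rest are already aligned.
# The greedy pass assigns 512 to rank 0, then 256, 128, and 64 to rank 1,
# leaving per-rank totals of 512 vs. 448 bytes.
assert distribute_buffer_sizes(buffer_sizes=(128, 64, 500, 256), group_size=2) == (
    (128, 1),
    (64, 1),
    (512, 0),
    (256, 1),
)
```

Note that the call sites now pass the group size explicitly: the DDP distributor uses `group_size=self._group_size`, while the HSDP and hybrid-shard distributors below pass `group_size=self._dist_group_size`.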
75 changes: 4 additions & 71 deletions distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -7,9 +7,7 @@

"""

import heapq
import logging
import operator
from functools import partial
from itertools import islice
from math import prod
@@ -29,6 +27,7 @@
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
from distributed_shampoo.utils.shampoo_utils import (
compress_list,
distribute_buffer_sizes,
generate_pairwise_indices,
get_dtype_size,
merge_small_dims,
@@ -193,11 +192,12 @@ def __init__(
comms_group_rank: int = dist.get_rank(self._comms_dist_group)

# Assign ranks to blocks with their respective buffer size.
buffer_size_ranks = self._distribute_buffer_sizes(
buffer_size_ranks = distribute_buffer_sizes(
buffer_sizes=tuple(
blocked_param.numel() * get_dtype_size(communication_dtype)
for blocked_param in self._global_blocked_params
)
),
group_size=self._dist_group_size,
)

global_block_info_list = self._construct_global_block_info_list(
@@ -289,73 +289,6 @@ def update_params(
self._global_masked_dist_blocked_buffers,
)

def _distribute_buffer_sizes(
self,
buffer_sizes: tuple[int, ...],
) -> tuple[tuple[int, int], ...]:
"""Distribute given buffer sizes across ranks in a group.

Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
total buffer sizes of each rank are as even as possible. This is currently performed
using a greedy algorithm. We do not currently consider computational cost
or kernel launching overheads.

Note: A better distribution strategy should try to minimize the delta of buffer sizes
between the most and the least allocated groups.

Args:
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.

Returns:
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
buffer size for each block and its assigned rank.

Example:
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]

"""
ALIGNMENT_BYTES = (
64 # necessary for determining buffer size, possibly hardware-dependent
)

# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
aligned_buffer_sizes = [
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
for buffer_size in buffer_sizes
]
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
allocated_buffer_sizes = [
(0, group_index) for group_index in range(self._dist_group_size)
]
heapq.heapify(allocated_buffer_sizes)

for index, aligned_buffer_size in sorted(
enumerate(aligned_buffer_sizes),
key=operator.itemgetter(1),
reverse=True,
):
# Greedily find the group with the least allocated buffer size and its group index
# in order to allocate buffers on that group.
(
min_allocated_buffer_size,
min_allocated_buffer_size_group_index,
) = heapq.heappop(allocated_buffer_sizes)

heapq.heappush(
allocated_buffer_sizes,
(
min_allocated_buffer_size + aligned_buffer_size,
min_allocated_buffer_size_group_index,
),
)
buffer_size_ranks[index] = (
aligned_buffer_size,
min_allocated_buffer_size_group_index,
)

return tuple(buffer_size_ranks)

def _construct_composable_block_ids(
self,
param_index: int,
75 changes: 4 additions & 71 deletions distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
@@ -7,9 +7,7 @@

"""

import heapq
import logging
import operator
from functools import partial
from itertools import islice
from typing import Any, Iterable
@@ -25,6 +23,7 @@
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
from distributed_shampoo.utils.shampoo_utils import (
compress_list,
distribute_buffer_sizes,
generate_pairwise_indices,
get_dtype_size,
)
@@ -179,11 +178,12 @@ def __init__(
comms_group_rank: int = dist.get_rank(self._comms_dist_group)

# Assign ranks to blocks with their respective buffer size.
buffer_size_ranks = self._distribute_buffer_sizes(
buffer_size_ranks = distribute_buffer_sizes(
buffer_sizes=tuple(
blocked_param.numel() * get_dtype_size(communication_dtype)
for blocked_param in self._global_blocked_params
)
),
group_size=self._dist_group_size,
)

global_block_info_list = self._construct_global_block_info_list(
@@ -293,73 +293,6 @@ def update_params(
self._global_masked_dist_blocked_buffers,
)

def _distribute_buffer_sizes(
self,
buffer_sizes: tuple[int, ...],
) -> tuple[tuple[int, int], ...]:
"""Distribute given buffer sizes across ranks in a group.

Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
total buffer sizes of each rank are as even as possible. This is currently performed
using a greedy algorithm. We do not currently consider computational cost
or kernel launching overheads.

Note: A better distribution strategy should try to minimize the delta of buffer sizes
between the most and the least allocated groups.

Args:
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.

Returns:
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
buffer size for each block and its assigned rank.

Example:
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]

"""
ALIGNMENT_BYTES = (
64 # necessary for determining buffer size, possibly hardware-dependent
)

# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
aligned_buffer_sizes = [
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
for buffer_size in buffer_sizes
]
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
allocated_buffer_sizes = [
(0, group_index) for group_index in range(self._dist_group_size)
]
heapq.heapify(allocated_buffer_sizes)

for index, aligned_buffer_size in sorted(
enumerate(aligned_buffer_sizes),
key=operator.itemgetter(1),
reverse=True,
):
# Greedily find the group with the least allocated buffer size and its group index
# in order to allocate buffers on that group.
(
min_allocated_buffer_size,
min_allocated_buffer_size_group_index,
) = heapq.heappop(allocated_buffer_sizes)

heapq.heappush(
allocated_buffer_sizes,
(
min_allocated_buffer_size + aligned_buffer_size,
min_allocated_buffer_size_group_index,
),
)
buffer_size_ranks[index] = (
aligned_buffer_size,
min_allocated_buffer_size_group_index,
)

return tuple(buffer_size_ranks)

def _construct_composable_block_ids(
self,
param_index: int,
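With the logic consolidated into one helper, a single unit test can cover the behavior relied on by all three distributors. A minimal sketch, assuming the shared helper is importable as above (the test file, class, and method names here are hypothetical, shown only to illustrate the invariants of the greedy assignment):

```python
import unittest

from distributed_shampoo.utils.shampoo_utils import distribute_buffer_sizes


class DistributeBufferSizesTest(unittest.TestCase):
    def test_docstring_example(self) -> None:
        # Mirrors the example from the removed docstrings.
        self.assertEqual(
            distribute_buffer_sizes(buffer_sizes=(128, 64, 500, 256), group_size=2),
            ((128, 1), (64, 1), (512, 0), (256, 1)),
        )

    def test_alignment_and_assignment_invariants(self) -> None:
        buffer_sizes = (1, 63, 64, 65, 1000)
        group_size = 3
        result = distribute_buffer_sizes(buffer_sizes=buffer_sizes, group_size=group_size)
        for (aligned_size, rank), original_size in zip(result, buffer_sizes):
            # Every buffer is rounded up to the nearest multiple of 64 bytes...
            self.assertEqual(aligned_size % 64, 0)
            self.assertGreaterEqual(aligned_size, original_size)
            self.assertLess(aligned_size - original_size, 64)
            # ...and assigned to a valid rank within the group.
            self.assertIn(rank, range(group_size))


if __name__ == "__main__":
    unittest.main()
```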