# shampoo_utils.py
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
import heapq
import math
import operator
from collections.abc import Callable, Iterator, Sequence
from functools import partial, reduce
from itertools import accumulate, chain, compress, pairwise
from types import TracebackType
from typing import Type, TypeVar
import torch
from torch import Tensor
def merge_small_dims(tensor_shape: Sequence[int], threshold: int) -> tuple[int, ...]:
"""Reshapes tensor by merging small dimensions.
Args:
tensor_shape (Sequence[int]): The shape of the tensor.
threshold (int): Threshold on the maximum size of each dimension.
Returns:
new_tensor_shape (tuple[int, ...]): New tensor shape.
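Example (illustrative):
    merge_small_dims((1, 2, 5, 1), threshold=10) squeezes out the size-1
    dimensions and merges the remaining [2, 5] into a single dimension,
    returning (10,). With threshold=8, the product 2 * 5 = 10 exceeds the
    threshold, so the result is (2, 5).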
"""
# Squeeze the tensor shape to remove dimensions of size 1; if all dimensions are 1,
# then fall back to a single dimension of size 1.
squeezed_tensor_shape = list(filter(lambda t: t != 1, tensor_shape)) or [1]
new_tensor_shape = [squeezed_tensor_shape[0]]
for next_tensor_shape in squeezed_tensor_shape[1:]:
if (new_dimension := new_tensor_shape[-1] * next_tensor_shape) <= threshold:
new_tensor_shape[-1] = new_dimension
else:
new_tensor_shape.append(next_tensor_shape)
return tuple(new_tensor_shape)
def multi_dim_split(tensor: Tensor, split_size: int) -> tuple[Tensor, ...]:
"""Chunks tensor across multiple dimensions based on splits.
Args:
tensor (Tensor): Gradient or tensor to split.
split_size (int): Size of a single chunk.
Returns:
split_tensors (tuple[Tensor, ...]): Tuple of tensors resulting from the splits.
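Example (illustrative):
    Splitting a tensor of shape (5, 3) with split_size=3 first splits along
    dim 0 into shapes (3, 3) and (2, 3); splitting along dim 1 then leaves
    them unchanged, so a tuple of two tensors with shapes (3, 3) and (2, 3)
    is returned.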
"""
return reduce(
lambda split_tensors, dim: tuple(
s for t in split_tensors for s in torch.split(t, split_size, dim=dim)
),
range(tensor.dim()),
(tensor,),
)
CompressListType = TypeVar("CompressListType")
def compress_list(
complete_list: Sequence[CompressListType], selector: Sequence[bool]
) -> tuple[CompressListType, ...]:
"""Compresses sequence based on selector.
NOTE: Despite the name, this function can compress both lists and tuples, but will always return
a tuple in order to ensure downstream compatibility.
Args:
complete_list (Sequence[CompressListType]): Complete tuple of candidates.
selector (Sequence[bool]): Mask that is True if state is active, False otherwise.
Returns:
compressed_tuple (tuple[CompressListType, ...]): Compressed list of candidates based on selector.
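Example (illustrative):
    compress_list(["a", "b", "c"], [True, False, True]) returns ("a", "c").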
"""
assert (
len(complete_list) == len(selector)
), f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!"
return tuple(compress(complete_list, selector))
def get_dtype_size(dtype: torch.dtype) -> int:
"""Return the size (bytes) of a given data type."""
if dtype is torch.bool:
return 1
return math.ceil(
(torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits / 8.0
)
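# For reference (illustrative values): get_dtype_size(torch.float32) == 4,
# get_dtype_size(torch.bfloat16) == 2, and get_dtype_size(torch.bool) == 1.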
def generate_pairwise_indices(input_list: Sequence[int]) -> Iterator[tuple[int, int]]:
"""Generates accumulated pairwise indices for a given input list.
For example, if input_list = (1, 3, 2), then this will output [(0, 1), (1, 4), (4, 6)].
This is useful for generating interval indices for iterating through a list given the
number of blocks within each parameter.
Args:
input_list (Sequence[int]): A list of integers specifying the number of elements within each partition.
Returns:
partition_indices (Iterator[tuple[int, int]]): An iterator of index pairs specifying
the starting and ending indices of each partition in input_list.
"""
return pairwise(accumulate(chain([0], input_list)))
ParameterizeEnterExitContextType = TypeVar("ParameterizeEnterExitContextType")
class ParameterizeEnterExitContext:
"""ParameterizeEnterExitContext is used for automatically invoking the enter and exit methods on the input within this context.
Args:
input_with_enter_exit_context (ParameterizeEnterExitContextType): Input whose state will be changed while entering and exiting the context by enter_method_caller and exit_method_caller, respectively.
enter_method_caller (Callable[[ParameterizeEnterExitContextType], None]): Method caller for entering the context.
exit_method_caller (Callable[[ParameterizeEnterExitContextType], None]): Method caller for exiting the context.
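Example (illustrative; assumes a hypothetical object obj exposing enable() and disable() methods):
    with ParameterizeEnterExitContext(
        input_with_enter_exit_context=obj,
        enter_method_caller=lambda o: o.enable(),
        exit_method_caller=lambda o: o.disable(),
    ):
        ...  # obj.enable() has been called; obj.disable() runs on exit.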
"""
def __init__(
self,
input_with_enter_exit_context: ParameterizeEnterExitContextType,
enter_method_caller: Callable[[ParameterizeEnterExitContextType], None],
exit_method_caller: Callable[[ParameterizeEnterExitContextType], None],
) -> None:
self._enter_method: Callable[[], None] = partial(
enter_method_caller, input_with_enter_exit_context
)
self._exit_method: Callable[[], None] = partial(
exit_method_caller, input_with_enter_exit_context
)
def __enter__(self) -> "ParameterizeEnterExitContext":
self._enter_method()
return self
def __exit__(
self,
exc_type: Type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._exit_method()
def distribute_buffer_sizes(
buffer_sizes: tuple[int, ...],
group_size: int,
) -> tuple[tuple[int, int], ...]:
"""Distribute given buffer sizes across ranks in a group.
Buffer sizes are rounded up for memory allocation. Buffers are distributed such that
the total buffer size assigned to each rank is as even as possible. This is currently
performed using a greedy algorithm. We do not currently consider computational cost
or kernel launching overheads.
Note: A better distribution strategy should try to minimize the delta of total buffer
sizes between the most and the least allocated groups.
Args:
buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.
group_size (int): Number of groups to distribute across.
Returns:
buffer_size_ranks (tuple[tuple[int, int], ...]): A tuple of pairs containing the
aligned buffer size for each block and its assigned rank.
Example:
Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
-> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]
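Worked trace (illustrative): the aligned sizes are (128, 64, 512, 256).
Processing in decreasing order, 512 is assigned to rank 0 (load 512) and
256 to rank 1 (load 256); 128 and 64 then both go to rank 1 as the
least-loaded rank, giving final loads of 512 and 448.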
"""
ALIGNMENT_BYTES = (
64 # necessary for determining buffer size, possibly hardware-dependent
)
# Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size.
aligned_buffer_sizes = [
(buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES
for buffer_size in buffer_sizes
]
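# For example, with ALIGNMENT_BYTES = 64, a buffer size of 500 is rounded up
# to 512, while 64 and 128 are already aligned and left unchanged.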
buffer_size_ranks = [(-1, -1)] * len(buffer_sizes)
allocated_buffer_sizes = [(0, group_index) for group_index in range(group_size)]
heapq.heapify(allocated_buffer_sizes)
for index, aligned_buffer_size in sorted(
enumerate(aligned_buffer_sizes),
key=operator.itemgetter(1),
reverse=True,
):
# Greedily find the group with the least allocated buffer size and its group index
# in order to allocate buffers on that group.
(
min_allocated_buffer_size,
min_allocated_buffer_size_group_index,
) = heapq.heappop(allocated_buffer_sizes)
heapq.heappush(
allocated_buffer_sizes,
(
min_allocated_buffer_size + aligned_buffer_size,
min_allocated_buffer_size_group_index,
),
)
buffer_size_ranks[index] = (
aligned_buffer_size,
min_allocated_buffer_size_group_index,
)
return tuple(buffer_size_ranks)