forked from facebookresearch/optimizers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshampoo_hsdp_distributor.py
940 lines (807 loc) · 40.2 KB
/
shampoo_hsdp_distributor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
import logging
from functools import partial
from itertools import islice
from math import prod
from typing import Any
import torch
from distributed_shampoo.shampoo_types import (
CommunicationDType,
FSDPParameterMetadata,
HSDPShampooConfig,
MAX_PRECONDITIONER_DIM,
PARAMS,
USE_MERGE_DIMS,
)
from distributed_shampoo.utils.shampoo_block_info import DDPBlockInfo
from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh
from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
from distributed_shampoo.utils.shampoo_utils import (
compress_list,
distribute_buffer_sizes,
generate_pairwise_indices,
get_dtype_size,
merge_small_dims,
multi_dim_split,
)
from torch import distributed as dist, Tensor
from torch.distributed import tensor as dtensor
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.tensor import zeros as dtensor_zeros
from torch.nn import Parameter
# Module-level logger; HSDPDistributor.__init__ uses it to report the
# num_trainers_per_group == -1 fallback to the replicated group size.
logger: logging.Logger = logging.getLogger(__name__)
class HSDPDistributor(DistributorInterface):
    """HSDP Distributor class.

    Handles split tensor block recovery of different parameters, then merging and blocking of
    the tensor blocks, as well as distributing of the parameters at instantiation.
    The constructor internally sets up `DeviceMesh` objects as necessary for distributing memory
    and computation, so torch.distributed must be initialized in advance.

    Unlike FSDPDistributor, HSDPDistributor requires the user to pass in a device mesh used for
    HSDP. For example, suppose we have 48 GPUs and the HSDP group size is 8. Then:

    HSDP Device Mesh with (Replicate, Shard) = (6, 8):

        device_mesh = [[ 0,  1,  2,  3,  4,  5,  6,  7]
                       [ 8,  9, 10, 11, 12, 13, 14, 15]
                       [16, 17, 18, 19, 20, 21, 22, 23]
                       [24, 25, 26, 27, 28, 29, 30, 31]
                       [32, 33, 34, 35, 36, 37, 38, 39]
                       [40, 41, 42, 43, 44, 45, 46, 47]]

    For example, if my device is rank 11, then:

        device_mesh["replicate"] = [3, 11, 19, 27, 35, 43]
        device_mesh["shard"] = [8, 9, 10, 11, 12, 13, 14, 15]

    Since the parameters are sharded along the "shard" dimension, we would normally replicate the
    computation along the "replicate" dimension. With HSDP Shampoo, we instead want to distribute
    the computation and memory requirements across the "replicate" dimension of the original HSDP
    device mesh.

    For example, suppose that the num_trainers_per_group = 3. We want to form a (2, 3)-submesh on
    the ranks [3, 11, 19, 27, 35, 43] (and similar).

    HSDPDistributor 2D Sub-Mesh Example with (Replicate, Shard) = (2, 3):

        submesh = [[ 3, 11, 19]
                   [27, 35, 43]]

    In this case, optimizer states will live on different "replicate" meshes: {[3, 27], [11, 35],
    [19, 43]}. In order to synchronize the optimizer step, we will communicate along the "shard"
    mesh {[3, 11, 19], [27, 35, 43]}.

    Args:
        param_group (dict[str, Any]): Parameter group containing parameters.
        distributed_config (HSDPShampooConfig): Configuration for HSDP Shampoo.

    Raises:
        RuntimeError: If torch.distributed has not been initialized.
        ValueError: If distributed_config.num_trainers_per_group is not -1 and either lies
            outside [1, replicated group size] or does not divide the replicated group size.
    """

    def __init__(
        self,
        param_group: dict[str, Any],
        distributed_config: HSDPShampooConfig,
    ) -> None:
        # These attributes are set before super().__init__() because
        # _merge_and_block_parameters reads self._param_to_metadata; presumably the base
        # class constructor invokes it (TODO confirm against DistributorInterface).
        self._param_to_metadata: dict[Parameter, FSDPParameterMetadata] = (
            distributed_config.param_to_metadata
        )
        self._hsdp_device_mesh: torch.distributed.device_mesh.DeviceMesh = (
            distributed_config.device_mesh
        )
        self._global_num_splits_per_param: tuple[int, ...] = ()
        self._global_num_blocks_per_split_param: tuple[int, ...] = ()

        super().__init__(param_group)
        if not dist.is_initialized():
            raise RuntimeError(
                "HSDPDistributor needs torch.distributed to be initialized!"
            )

        # Construct global masked blocked parameters (which is DDP-specific).
        self._global_masked_blocked_params: tuple[Tensor, ...] = (
            self._global_blocked_params
        )

        # Check num_trainers_per_group and replicated group size.
        # NOTE: If num_trainers_per_group = -1, then we use the replicated group size.
        self._replicated_group_size: int = self._hsdp_device_mesh.size(0)

        if not (
            1
            <= distributed_config.num_trainers_per_group
            <= self._replicated_group_size
            or distributed_config.num_trainers_per_group == -1
        ):
            raise ValueError(
                f"Invalid number of trainers per group: {distributed_config.num_trainers_per_group}. "
                f"Must be between [1, {self._replicated_group_size}] or set to -1."
            )
        if distributed_config.num_trainers_per_group == -1:
            logger.info(
                f"Note that {distributed_config.num_trainers_per_group=}! Defaulting to replicated group size {self._replicated_group_size}."
            )
        elif (
            self._replicated_group_size % distributed_config.num_trainers_per_group
            != 0
        ):
            raise ValueError(
                f"{distributed_config.num_trainers_per_group=} must divide {self._replicated_group_size=}!"
            )

        # Group size for distributing computation / memory requirements.
        self._dist_group_size: int = (
            distributed_config.num_trainers_per_group
            if distributed_config.num_trainers_per_group != -1
            else self._replicated_group_size
        )

        # Create flag for distributing parameters instead of search directions.
        self._communicate_params: bool = distributed_config.communicate_params

        # Determine communication type.
        if distributed_config.communication_dtype == CommunicationDType.BF16:
            communication_dtype = torch.bfloat16
        elif distributed_config.communication_dtype == CommunicationDType.FP16:
            communication_dtype = torch.float16
        else:
            assert distributed_config.communication_dtype in [
                CommunicationDType.FP32,
                CommunicationDType.DEFAULT,
            ]
            communication_dtype = torch.float32

        # Initialize _comms_dist_group and comms_group_rank.
        # Note that this requires initializing all process groups, so every rank builds the
        # device mesh for every replicated group, not only the one it belongs to.
        # Splits replicated ranks group into smaller groups of size self._dist_group_size.
        # Instantiates this by using DeviceMesh.
        ranks_in_all_replicated_groups = self._hsdp_device_mesh.mesh.T
        for ranks_in_replicated_group in ranks_in_all_replicated_groups:
            device_mesh = get_device_mesh(
                device_type=self._hsdp_device_mesh.device_type,
                mesh=tuple(
                    tuple(ranks_in_replicated_subgroup)
                    for ranks_in_replicated_subgroup in ranks_in_replicated_group.view(
                        -1, self._dist_group_size
                    ).tolist()
                ),
                mesh_dim_names=("replicate", "shard"),
            )
            if dist.get_rank() in ranks_in_replicated_group:
                # NOTE: We want the process group in the device mesh that the current rank
                # belongs to but solely along the "shard" dimension for communications.
                #
                # For example, if the current rank is 11, then I want the process group
                # that contains the ranks [3, 11, 19].
                self._comms_dist_group: dist.ProcessGroup = device_mesh.get_group(
                    "shard"
                )

        comms_group_rank: int = dist.get_rank(self._comms_dist_group)

        # Assign ranks to blocks with their respective buffer size.
        buffer_size_ranks = self._distribute_buffer_sizes(
            buffer_sizes=tuple(
                blocked_param.numel() * get_dtype_size(communication_dtype)
                for blocked_param in self._global_blocked_params
            )
        )

        global_block_info_list = self._construct_global_block_info_list(
            group_source_ranks=tuple(
                group_source_rank for _, group_source_rank in buffer_size_ranks
            )
        )
        # Initialize selectors and local blocked (masked) parameters.
        self._distributor_selector: tuple[bool, ...] = tuple(
            block_info.group_source_rank == comms_group_rank
            for block_info in global_block_info_list
        )
        self._local_blocked_params: tuple[Tensor, ...] = compress_list(
            self._global_blocked_params, self._distributor_selector
        )
        self._local_masked_blocked_params: tuple[Tensor, ...] = (
            self._local_blocked_params
        )
        self._local_grad_selector: tuple[bool, ...] = (True,) * len(
            self._local_blocked_params
        )
        self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
            global_block_info_list, self._distributor_selector
        )

        self._construct_distributed_buffers(
            buffer_size_ranks=buffer_size_ranks,
            communication_dtype=communication_dtype,
            comms_group_rank=comms_group_rank,
        )

    # NOTE: Remove this function once PT2 supports all_gather with functional collective
    @torch.no_grad()
    @torch.compiler.disable
    def all_gather_into_tensor(self) -> None:
        """All-gather the local communication buffer of every rank into the global buffer.

        Gathers self._local_dist_buffer from each rank of self._comms_dist_group into
        self._global_dist_buffer.
        """
        dist.all_gather_into_tensor(
            self._global_dist_buffer,
            self._local_dist_buffer,
            group=self._comms_dist_group,
        )

    @torch.no_grad()
    def update_params(
        self,
        masked_blocked_search_directions: tuple[Tensor, ...],
    ) -> None:
        """Update params stored inside this distributor according to the input search directions argument.

        Args:
            masked_blocked_search_directions (tuple[Tensor, ...]): Search directions for each local blocked parameter.

        See the comment in the parent class for details.

        """
        if self._communicate_params:
            # Perform your update to your local masked parameters and copy into buffers.
            torch._foreach_add_(
                self._local_masked_blocked_params,
                masked_blocked_search_directions,
            )
            torch._foreach_copy_(
                self._local_masked_dist_blocked_buffers,
                self._local_masked_blocked_params,
            )

            self.all_gather_into_tensor()

            # Copy updated blocked params in global_masked_dist_blocked_buffers
            # into global_masked_blocked_params.
            torch._foreach_copy_(
                self._global_masked_blocked_params,
                self._global_masked_dist_blocked_buffers,
            )

        else:
            # Search directions multiplied by alpha are distributed.
            # Copy the local search directions to the communication buffer.
            torch._foreach_copy_(
                self._local_masked_dist_blocked_buffers,
                masked_blocked_search_directions,
            )

            self.all_gather_into_tensor()

            # Add search directions in global_masked_dist_blocked_buffers
            # to global_masked_blocked_params.
            torch._foreach_add_(
                self._global_masked_blocked_params,
                self._global_masked_dist_blocked_buffers,
            )

    def _distribute_buffer_sizes(
        self,
        buffer_sizes: tuple[int, ...],
    ) -> tuple[tuple[int, int], ...]:
        """Distribute given buffer sizes across ranks in a group.

        Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that
        total buffer sizes of each rank are as even as possible. This is currently performed
        using a greedy algorithm. We do not currently consider computational cost
        or kernel launching overheads.

        Note: A better distribution strategy should try to minimize the delta of buffer sizes
        between the most and the least allocated groups.

        Args:
            buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed.

        Returns:
            buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
                buffer size for each block and its assigned rank.

        Example:
            Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2
            -> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)]

        """
        return distribute_buffer_sizes(buffer_sizes, self._dist_group_size)

    def _construct_composable_block_ids(
        self,
        param_index: int,
        block_index: int,
        rank: int | None = None,
    ) -> tuple[int, str]:
        """Construct composable block ids for each parameter.

        Args:
            param_index (int): Index of the current parameter within self._param_group[PARAMS].
            block_index (int): Index of the block within the current parameter.
            rank (int | None): Rank used to prefix the block name; should be non None in FSDP/HSDP setting. (Default: None)

        Returns:
            tuple[int, str]: Composable block id tuple containing the parameter index and the local
                block name "rank_{rank}-block_{block_index}". The latter will be used to identify
                blocks in the masked tensor.

        """
        return (param_index, f"rank_{rank}-block_{block_index}")

    @torch.no_grad()
    def _construct_global_block_info_list(
        self, group_source_ranks: tuple[int, ...]
    ) -> tuple[DDPBlockInfo, ...]:
        """Construct the global block info list.

        This method creates a list of DDPBlockInfo objects, which contain information
        about each parameter block, including its composable block IDs, a function to
        allocate zero tensors, a method to retrieve tensors, and the group source rank.

        Args:
            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.

        Returns:
            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
        """
        # Note that for HSDP, we want to get the rank within each sharded group for the block id.
        # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
        sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1)

        return tuple(
            DDPBlockInfo(
                param=param,
                composable_block_ids=self._construct_composable_block_ids(
                    param_index=param_index,
                    block_index=block_index,
                    rank=sharded_group_rank,
                ),
                allocate_zeros_tensor=partial(
                    self._allocate_zeros_distributed_tensor,
                    group_source_rank=group_source_rank,
                ),
                get_tensor=lambda input_tensor: (
                    input_tensor.to_local()
                    if isinstance(input_tensor, dtensor.DTensor)
                    else input_tensor
                ),
                group_source_rank=group_source_rank,
            )
            for (
                (param_index, param),
                (buffer_size_ranks_start, buffer_size_ranks_end),
            ) in zip(
                enumerate(self._param_group[PARAMS]),
                generate_pairwise_indices(self._global_num_blocks_per_param),
                strict=True,
            )
            for block_index, group_source_rank in enumerate(
                islice(
                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                )
            )
        )

    def _merge_and_block_parameters(
        self,
    ) -> None:
        """Split, merge, and block parameters."""
        global_blocked_params: list[Tensor] = []
        # self._global_num_splits_per_param refers to the total number of splits within each
        # flattened parameter (obtained by split tensor block recovery).
        # This has the same length as the number of flattened parameters contained in
        # self._param_group[PARAMS].
        global_num_splits_per_param = []
        # self._global_num_blocks_per_split_param refers to the total number of blocks within each
        # split parameter.
        # This has the same length as the number of split parameters.
        global_num_blocks_per_split_param = []
        # self._global_merged_dims_list has the same length as the total number of split tensor
        # blocks within all flattened parameters obtained from split tensor block recovery.
        global_merged_dims_list = []

        for flattened_param in self._param_group[PARAMS]:
            # Split flattened parameters into valid tensor blocks of the parameter.
            split_params = HSDPDistributor._split_tensor_block_recovery(
                flattened_param,
                self._param_to_metadata[flattened_param].shape,
                self._param_to_metadata[flattened_param].start_idx,
                self._param_to_metadata[flattened_param].end_idx,
            )
            global_num_splits_per_param.append(len(split_params))

            for split_param in split_params:
                # Obtain blocks for each parameter after merging.
                merged_dims = (
                    merge_small_dims(
                        split_param.size(), self._param_group[MAX_PRECONDITIONER_DIM]
                    )
                    if self._param_group[USE_MERGE_DIMS]
                    else split_param.size()
                )
                blocks_within_split_param = multi_dim_split(
                    split_param.view(merged_dims),
                    self._param_group[MAX_PRECONDITIONER_DIM],
                )

                # Generate and extend block info list and extend blocked parameters list.
                # Note that the block info list should have the same length as the blocked parameters list.
                global_blocked_params.extend(
                    # Note: We are using tensor.detach() here to explicitly set block_param (a view of the original
                    # parameter) to requires_grad = False in order to prevent errors with print and PT2 compile.
                    # Remove this tensor.detach() once https://github.com/pytorch/pytorch/issues/113793 is fixed.
                    block_param.detach()
                    for block_param in blocks_within_split_param
                )

                # Stores the merged dimensions for each parameter and the number of blocks for each param so
                # we could use this later for constructing the mask on filtering blocks when grad is None.
                global_merged_dims_list.append(merged_dims)
                global_num_blocks_per_split_param.append(len(blocks_within_split_param))

        # Compute the number of blocks for each parameter as the sum of the number of blocks
        # of its split parameters.
        self._global_num_blocks_per_param = tuple(
            sum(global_num_blocks_per_split_param[split_index:next_split_index])
            for split_index, next_split_index in generate_pairwise_indices(
                global_num_splits_per_param
            )
        )

        # Set lists as tuples.
        self._global_blocked_params = tuple(global_blocked_params)
        self._global_num_splits_per_param = tuple(global_num_splits_per_param)
        self._global_num_blocks_per_split_param = tuple(
            global_num_blocks_per_split_param
        )
        self._global_merged_dims_list = tuple(global_merged_dims_list)

    @staticmethod
    def _split_local_dist_buffers(
        buffer_size_ranks: tuple[tuple[int, int], ...],
        local_dist_buffers: tuple[torch.Tensor, ...],
    ) -> tuple[torch.Tensor, ...]:
        """Split distributed buffers for each local rank into views for each assigned block.

        Args:
            buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
                buffer size and an assigned rank for each block.
            local_dist_buffers (tuple[torch.Tensor, ...]): A list of local distributed buffers that
                correspond to each rank. Each distributed buffer will be split according to the
                assigned tensor blocks.

        Returns:
            splitted_local_dist_buffers (tuple[torch.Tensor, ...]): A list of tuples containing a view of the
                local distributed buffer for each tensor block.

        Example:
            tensor0 = tensor(1024)
            tensor1 = tensor(1024)
            buffer_size_ranks = [(128, 0), (64, 0), (512, 1), (256, 0)]
            local_dist_buffers = [tensor0, tensor1]
            -> splitted_local_dist_buffers = [
                tensor0's view(  0-128 bytes),
                tensor0's view(128-192 bytes),
                tensor1's view(  0-512 bytes),
                tensor0's view(192-448 bytes),
            ]

        """

        # Create list of lists containing local views of each split tensor for each rank.
        split_tensors_list = []
        for rank, local_dist_buffer in enumerate(local_dist_buffers):
            required_buffer_sizes = [s for s, r in buffer_size_ranks if r == rank]
            remainder_size = local_dist_buffer.size(0) - sum(required_buffer_sizes)
            # The message is parenthesized so both f-string fragments form one message;
            # previously the second fragment was a separate (no-op) statement.
            assert remainder_size >= 0, (
                f"Local distributed buffer size {local_dist_buffer.size(0)} is "
                f"not larger than or equal to the sum of buffer sizes {sum(required_buffer_sizes)}!"
            )
            split_tensors = torch.split(
                local_dist_buffer, required_buffer_sizes + [remainder_size]
            )
            split_tensors_list.append(split_tensors)

        # Obtain ordered buffer ranks containing (view of local buffer, rank).
        splitted_local_dist_buffers = []
        # Index counter for each rank for obtaining the right buffer.
        buffer_indices = [0] * len(local_dist_buffers)
        for _, rank in buffer_size_ranks:
            splitted_local_dist_buffers.append(
                split_tensors_list[rank][buffer_indices[rank]]
            )
            buffer_indices[rank] += 1

        return tuple(splitted_local_dist_buffers)

    def _construct_distributed_buffers(
        self,
        buffer_size_ranks: tuple[tuple[int, int], ...],
        communication_dtype: torch.dtype,
        comms_group_rank: int,
    ) -> None:
        """Construct the distributed buffers for AllGather communications.

        Note that this function will construct the distributed buffer for the AllGather
        communication. In addition, it massages the distributed buffer to obtain views
        of the buffer corresponding to each block assigned to the current rank.

        Args:
            buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the
                buffer size and an assigned rank for each block.
            communication_dtype (torch.dtype): The data type used for communication.
            comms_group_rank (int): The rank of the current group within the comms group.

        """

        # Calculate buffer size each rank needs.
        local_buffer_sizes = tuple(
            sum(buffer_size for buffer_size, rank in buffer_size_ranks if rank == i)
            for i in range(self._dist_group_size)
        )

        # Calculate the whole buffer size and obtain buffers for every rank.
        # Every rank gets an equally sized slot (the maximum) as required by all_gather_into_tensor.
        max_buffer_size_sum = max(local_buffer_sizes)
        total_buffer_size = max_buffer_size_sum * self._dist_group_size
        self._global_dist_buffer = torch.zeros(
            total_buffer_size,
            dtype=torch.int8,
            device=self._global_blocked_params[0].device,
        )
        local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
        splitted_local_dist_buffers = HSDPDistributor._split_local_dist_buffers(
            buffer_size_ranks, local_dist_buffers
        )

        # Get local buffer for specific group rank.
        self._local_dist_buffer = local_dist_buffers[comms_group_rank]

        # Obtain the list of buffers corresponding to each block (ignoring padding).
        # Note that each buffer is reshaped into the block's shape and viewed in terms
        # of the communication data type.
        self._global_dist_blocked_buffers = tuple(
            buffer.split(blocked_param.numel() * get_dtype_size(communication_dtype))[0]
            .view(communication_dtype)
            .view(blocked_param.shape)
            for buffer, blocked_param in zip(
                splitted_local_dist_buffers, self._global_blocked_params, strict=True
            )
        )
        self._local_dist_blocked_buffers = compress_list(
            self._global_dist_blocked_buffers, self._distributor_selector
        )
        self._global_masked_dist_blocked_buffers = self._global_dist_blocked_buffers
        self._local_masked_dist_blocked_buffers = self._local_dist_blocked_buffers

    def _merge_and_block_gradients(
        self,
    ) -> tuple[Tensor, ...]:
        """Split, merge, and block gradients.

        Also updates self._global_grad_selector with one entry per global block, set to
        whether the owning parameter currently has a gradient.

        Returns:
            local_masked_blocked_grads (tuple[Tensor, ...]): Local gradients with grad not None.

        """
        local_masked_blocked_grads: list[Tensor] = []
        global_grad_selector = []

        for (
            flattened_param,
            num_blocks,
            (block_index, next_block_index),
            (split_index, next_split_index),
        ) in zip(
            self._param_group[PARAMS],
            self._global_num_blocks_per_param,
            generate_pairwise_indices(self._global_num_blocks_per_param),
            generate_pairwise_indices(self._global_num_splits_per_param),
            strict=True,
        ):
            flattened_grad = flattened_param.grad
            param_distributor_selector = self._distributor_selector[
                block_index:next_block_index
            ]

            # Update the selector.
            global_grad_selector.extend([flattened_grad is not None] * num_blocks)

            if flattened_grad is None or not any(param_distributor_selector):
                # Skip split_tensor_block_recovery and multi_dim_split if this blocked grad will not be used locally.
                continue

            # Split flattened gradients into valid tensor blocks of the gradient.
            split_grads = HSDPDistributor._split_tensor_block_recovery(
                flattened_grad,
                self._param_to_metadata[flattened_param].shape,
                self._param_to_metadata[flattened_param].start_idx,
                self._param_to_metadata[flattened_param].end_idx,
            )

            # Get the merged dimensions and the number of blocks for each split gradient.
            merged_dims_within_flattened_param = self._global_merged_dims_list[
                split_index:next_split_index
            ]
            num_blocks_within_split_grads = self._global_num_blocks_per_split_param[
                split_index:next_split_index
            ]

            for (
                grad,
                merged_dims,
                (blocks_within_split_index, next_blocks_within_split_index),
            ) in zip(
                split_grads,
                merged_dims_within_flattened_param,
                generate_pairwise_indices(num_blocks_within_split_grads),
                strict=True,
            ):
                # Obtain blocks for each split gradient after merging.
                blocks_within_grad = multi_dim_split(
                    grad.view(merged_dims), self._param_group[MAX_PRECONDITIONER_DIM]
                )
                # Keep only the blocks assigned to this rank.
                local_masked_blocked_grads.extend(
                    compress_list(
                        blocks_within_grad,
                        param_distributor_selector[
                            blocks_within_split_index:next_blocks_within_split_index
                        ],
                    )
                )

        # Set global grad selector as tuple.
        self._global_grad_selector = tuple(global_grad_selector)

        return tuple(local_masked_blocked_grads)

    def merge_and_block_gradients(
        self,
    ) -> tuple[Tensor, ...]:
        """Merge and block gradients.

        NOTE: This function MUST be called in the step function of the optimizer after the
        gradient has been updated.

        Returns:
            local_masked_blocked_grads (tuple[Tensor, ...]): Local blocked gradients masked with grad existence.

        """
        local_masked_blocked_grads = self._merge_and_block_gradients()

        if self._previous_global_grad_selector != self._global_grad_selector:
            self._previous_global_grad_selector = self._global_grad_selector

            # Update _local_grad_selector and _local_masked_blocked_params only when global_grad_selector is changed.
            self._local_grad_selector = compress_list(
                self._global_grad_selector,
                self._distributor_selector,
            )
            self._local_masked_blocked_params = compress_list(
                self._local_blocked_params, self._local_grad_selector
            )

            # Re-compress DDP-specific tensor lists using the updated selector.
            self._global_masked_blocked_params = compress_list(
                self._global_blocked_params, self._global_grad_selector
            )
            self._global_masked_dist_blocked_buffers = compress_list(
                self._global_dist_blocked_buffers, self._global_grad_selector
            )
            self._local_masked_dist_blocked_buffers = compress_list(
                self._local_dist_blocked_buffers, self._local_grad_selector
            )

        return local_masked_blocked_grads

    @staticmethod
    def _split_tensor_block_recovery(
        tensor_shard: Tensor,
        original_shape: torch.Size,
        start_idx: int,
        end_idx: int,
    ) -> list[Tensor]:
        """Chunks flattened tensor in order to re-construct valid blocks with respect to the original
        multi-dimensional tensor shape and parameter boundaries.

        Starting from the first dimension, the largest possible slices in each dimension
        (with the remaining dimensions on the right retaining the original shape) are split off.

        The following is an example of how the function works for a 2-D tensor shard:

        Given an original tensor with shape (7, 14) in Fig. 1, we receive a flattened tensor shard from FSDP
        corresponding to Fig. 4. Note that this flattened tensor shard corresponds to the shard of the tensor
        in Fig. 2. In order to respect the tensor shape, we need to split the tensor into up to three blocks
        (as in Fig. 5). This requires splitting the tensor in Fig. 2 (see flattened tensor shard in Fig. 4)
        then reshaping each flattened split tensor into its original shape (see reshaped split tensors in Fig.
        3 and 6).

             ______________
            |       _______|                _______                 _______
            |______|       |         ______|       |         ______|_______|
            |              |   ->   |              |   ->   |              |
            |           ___|        |           ___|        |______________|
            |__________|   |        |__________|            |__________|
            |______________|

            original tensor           tensor_shard           split tensors
                Fig. 1                   Fig. 2                  Fig. 3

        Flattened original tensor in Fig. 1:
             ________________________________________________________________
            |____________________|_________________________|_________________|
                                 ^      tensor_shard       ^
                             start_idx                  end_idx
                                            Fig. 4

             ________________________________________________________________
            |____________________|______|_______________|__|_________________|
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
            ^^^ denotes the flattened split tensors in Fig. 3.
                                            Fig. 5

        Reshaped split tensors (i.e., the tensors in Fig. 3):

                 _______
                |_______|           <- left split
             ______________
            |              |        <- center split
            |______________|
             __________
            |__________|            <- right split
                       Fig. 6

        Args:
            tensor_shard (Tensor): A shard of the flattened version of original tensor to split.
            original_shape (torch.Size): Shape of original tensor that tensor_shard is a slice of.
            start_idx (int): Flattened index in the original tensor where tensor starts (inclusive).
            end_idx (int): Flattened index in the original tensor where tensor ends (exclusive).

        Returns:
            split_tensors (list[Tensor]): List of tensors.

        Raises:
            ValueError: If tensor_shard is not one-dimensional.

        """
        if len(tensor_shard.size()) != 1:
            raise ValueError(
                f"Input tensor is not flat, has shape {tensor_shard.size()=}."
            )

        def block_within_tensor_shard_recovery(
            block_within_tensor_shard: Tensor,
            dimension: int,
            block_start_idx: int,
            block_end_idx: int,
        ) -> list[Tensor]:
            # The message is parenthesized so both f-string fragments form one message;
            # previously the second fragment was a separate (no-op) statement.
            assert (
                block_end_idx - block_start_idx == block_within_tensor_shard.numel()
            ), (
                f"Start/end indices do not match tensor size: {block_start_idx=}, "
                f"{block_end_idx=}, {block_within_tensor_shard.numel()=}!"
            )

            if block_end_idx == block_start_idx:
                return []

            # Handle case where shape is one-dimensional.
            # Because it reached the last dimension, we can simply return the flattened tensor.
            if dimension == len(original_shape) - 1:
                return [block_within_tensor_shard]

            # Instantiate list of tensor blocks.
            center_split_tensor_blocks = []

            # Instantiates flattened indices for recursion.
            remaining_size = prod(original_shape[dimension + 1 :])

            #  ________________________________________________________________
            # |____________________|______|_______________|__|_________________|
            #                      ^      ^               ^  ^
            #        block_start_idx      |               |  block_end_idx
            #                             |               |
            #                  center_split_start_idx     |
            #                                   center_split_end_idx
            #
            # This came from Fig. 4 above.

            # Get starting index of the center split of the tensor shard. (See figure above.)
            # This is equal to ceil(block_start_idx / remaining_size) * remaining_size.
            center_split_start_idx = (
                (block_start_idx + remaining_size - 1) // remaining_size
            ) * remaining_size

            # Similarly, get end index of the center split of the tensor shard.
            # This is equal to floor(block_end_idx / remaining_size) * remaining_size.
            center_split_end_idx = block_end_idx // remaining_size * remaining_size

            # Handles largest convex partition in the center.
            if center_split_start_idx < center_split_end_idx:
                center_split_start_idx_in_block = (
                    center_split_start_idx - block_start_idx
                )
                length_of_center_split = center_split_end_idx - center_split_start_idx
                new_shape = [-1] + list(original_shape[dimension + 1 :])
                # NOTE: We use Tensor.narrow() instead of slicing in order to guarantee
                # there is no copy of the tensor.
                center_split_tensor_blocks.append(
                    block_within_tensor_shard.narrow(
                        0,
                        center_split_start_idx_in_block,
                        length_of_center_split,
                    ).view(new_shape)
                )
            elif center_split_start_idx > center_split_end_idx:
                # The block lies strictly inside one slice of this dimension, so there is no
                # center split: recursively call split tensor block recovery on the full
                # flattened tensor ignoring the first dimension of the original
                # tensor shape.
                return block_within_tensor_shard_recovery(
                    block_within_tensor_shard=block_within_tensor_shard,
                    dimension=dimension + 1,
                    block_start_idx=block_start_idx,
                    block_end_idx=block_end_idx,
                )

            # Recursively call split tensor block recovery on the left and right
            # splits of the flattened tensor.
            left_split_start_idx_in_block = 0
            left_split_tensor_size = center_split_start_idx - block_start_idx
            left_split_tensor_blocks = block_within_tensor_shard_recovery(
                block_within_tensor_shard=block_within_tensor_shard.narrow(
                    0,
                    start=left_split_start_idx_in_block,
                    length=left_split_tensor_size,
                ),
                dimension=dimension + 1,
                block_start_idx=block_start_idx,
                block_end_idx=center_split_start_idx,
            )

            center_split_end_idx_in_block = center_split_end_idx - block_start_idx
            right_split_tensor_size = block_end_idx - center_split_end_idx
            right_split_tensor_blocks = block_within_tensor_shard_recovery(
                block_within_tensor_shard=block_within_tensor_shard.narrow(
                    0,
                    start=center_split_end_idx_in_block,
                    length=right_split_tensor_size,
                ),
                dimension=dimension + 1,
                block_start_idx=center_split_end_idx,
                block_end_idx=block_end_idx,
            )

            return (
                left_split_tensor_blocks
                + center_split_tensor_blocks
                + right_split_tensor_blocks
            )

        return block_within_tensor_shard_recovery(
            block_within_tensor_shard=tensor_shard,
            dimension=0,
            block_start_idx=start_idx,
            block_end_idx=end_idx,
        )

    def _allocate_zeros_distributed_tensor(
        self,
        size: tuple[int, ...],
        dtype: torch.dtype,
        device: torch.device,
        group_source_rank: int,
    ) -> torch.Tensor:
        """Instantiates distributed tensor using DTensor.

        Args:
            size (tuple[int, ...]): Shape of desired tensor.
            dtype (torch.dtype): DType of desired tensor.
            device (torch.device): Device of desired tensor.
            group_source_rank (int): Group rank (with respect to the sharded group of
                the 2D submesh) that determines which ranks the DTensor is allocated on.

        Returns:
            out (Tensor): Desired Tensor.

        """
        ranks_in_replicated_group = torch.tensor(
            dist.get_process_group_ranks(self._hsdp_device_mesh.get_group(0))
        )
        device_mesh_2d = get_device_mesh(
            device_type=device.type,
            mesh=tuple(
                tuple(ranks_in_replicated_subgroup)
                for ranks_in_replicated_subgroup in ranks_in_replicated_group.view(
                    -1, self._dist_group_size
                ).tolist()
            ),
            mesh_dim_names=("replicate", "shard"),
        )

        # NOTE: We get all submeshes along the "replicate" dimension, then pick out
        # the sub-mesh that the optimizer state is assigned to.
        #
        # For the example above, this would give me submeshes [[3, 27], [11, 35], [19, 43]].
        # Note that the group source rank must belong to {0, 1, 2} in this case.
        # Suppose the group_source_rank = 1, then this would get the submesh [11, 35].
        replicate_submesh = _mesh_resources._get_all_submeshes(
            device_mesh_2d, "replicate"
        )[group_source_rank]

        return dtensor_zeros(
            size,
            dtype=dtype,
            device_mesh=replicate_submesh,
            placements=[dtensor.Replicate()],
        )