Skip to content

Commit 3068552

Browse files
tsunghsienleefacebook-github-bot
authored andcommitted
Open-sourced update on 03/24/2024
Summary: Add new test for comparing HSDP2 vs. FSDP2. Reviewed By: chuanhaozhuge Differential Revision: D71750579 fbshipit-source-id: 1de540961f55c7e49215d8162e0a39992dad9c80
1 parent 65b0d32 commit 3068552

File tree

2 files changed

+98
-28
lines changed

2 files changed

+98
-28
lines changed

distributed_shampoo/utils/gpu_tests/shampoo_hybrid_shard_distributor_test.py

+95-28
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from distributed_shampoo.shampoo_types import (
2222
AdaGradGraftingConfig,
2323
CommunicationDType,
24+
FullyShardShampooConfig,
2425
HybridShardShampooConfig,
2526
)
2627
from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem
@@ -29,7 +30,7 @@
2930
from torch import nn
3031
from torch.distributed._composable.fsdp import fully_shard
3132
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
32-
from torch.distributed.device_mesh import init_device_mesh
33+
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
3334
from torch.distributed.tensor import DTensor
3435
from torch.optim.optimizer import ParamsT
3536
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -52,27 +53,34 @@ def backend(self) -> str:
5253
@staticmethod
5354
def _construct_model(
5455
device: torch.device,
55-
distributed_config: HybridShardShampooConfig | None,
56+
distributed_config: FullyShardShampooConfig | HybridShardShampooConfig | None,
57+
device_mesh: DeviceMesh | None,
5658
) -> tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor, bool]:
5759
IN_DIM = 16
5860
data = torch.arange(IN_DIM, dtype=torch.float, device=device)
5961
data /= torch.norm(data)
60-
# NOTE: We construct the model here specifically in order to ensure that
61-
# FullyShard Shampoo and default Shampoo produce equivalent results.
62+
# NOTE: We construct the model here specifically in order to ensure that HybridShard
63+
# Shampoo, and default Shampoo produce equivalent results.
6264
# This requires us to construct a model such that FullyShard will split the
63-
# parameters such that the preconditioners created between the FullyShard
65+
# parameters such that the preconditioners created between the HybridShard Shampoo,
6466
# and default Shampoo are equivalent.
65-
# +----------------+
66-
# | [4, 16] |
67-
# | GPU0 |
68-
# -------------------- +------+
69-
# | [4, 16] | |[4, 4]|
70-
# | GPU1 | | |
71-
# +----------------+ +------+
67+
#
68+
# In a (2, 2) mesh, we have the following parameter distribution:
69+
#
70+
# +----------------+ +----------------+
71+
# | [4, 16] | | [4, 16] |
72+
# | GPU0 | | GPU1 |
73+
# -------------------- +------+ -------------------- +------+
74+
# | [4, 16] | |[4, 4]| | [4, 16] | |[4, 4]|
75+
# | GPU2 | | | | GPU3 | | |
76+
# +----------------+ +------+ +----------------+ +------+
77+
#
78+
# Each FSDP group has the complete model. (GPU0, GPU2) and (GPU1, GPU3) are
79+
# 2 FDSP groups.
80+
#
7281
# For the first linear layer, each GPU has a [4, 16] parameter. The blocked
73-
# parameters are of size [4, 4] and each GPU has four local blocks (eight
74-
# blocks in total). In comparison, with default shampoo, the eight blocks
75-
# are replicated on two GPUs.
82+
# parameters are of size [4, 4] and each GPU has four local blocks. In comparison,
83+
# with default shampoo, the eight blocks are replicated on four GPUs.
7684
# Similarly, the second linear layer has a [1, 8] parameter and is split
7785
# into two [4] chunks.
7886

@@ -86,16 +94,15 @@ def _construct_model(
8694
fill=0.1,
8795
)
8896

89-
if uses_hybrid_shard := isinstance(
90-
distributed_config, HybridShardShampooConfig
97+
if use_fsdp2 := (
98+
isinstance(
99+
distributed_config, (HybridShardShampooConfig, FullyShardShampooConfig)
100+
)
91101
):
92102
# Need this to get pass type-checking test.
93103
assert distributed_config is not None
94-
model = fully_shard(
95-
model,
96-
mesh=distributed_config.device_mesh,
97-
)
98-
return model, loss, data, target, uses_hybrid_shard
104+
model = fully_shard(model, mesh=device_mesh)
105+
return model, loss, data, target, use_fsdp2
99106

100107
@staticmethod
101108
def _train_model(
@@ -190,7 +197,7 @@ def _test_two_configs(
190197

191198
@staticmethod
192199
def _shampoo_optim_factory(
193-
distributed_config: HybridShardShampooConfig | None,
200+
distributed_config: FullyShardShampooConfig | HybridShardShampooConfig | None,
194201
) -> Callable[
195202
[ParamsT],
196203
torch.optim.Optimizer,
@@ -214,7 +221,8 @@ def _shampoo_optim_factory(
214221

215222
@staticmethod
216223
def _model_factory(
217-
distributed_config: HybridShardShampooConfig | None,
224+
distributed_config: FullyShardShampooConfig | HybridShardShampooConfig | None,
225+
device_mesh: DeviceMesh | None,
218226
) -> Callable[
219227
[torch.device],
220228
tuple[
@@ -228,6 +236,7 @@ def _model_factory(
228236
return partial(
229237
ShampooHybridShardDistributorTest._construct_model,
230238
distributed_config=distributed_config,
239+
device_mesh=device_mesh,
231240
)
232241

233242
@with_comms
@@ -266,12 +275,66 @@ def test_hybrid_shard_shampoo_against_default_shampoo(self) -> None:
266275
),
267276
ShampooHybridShardDistributorTest._model_factory(
268277
None,
278+
None,
279+
),
280+
ShampooHybridShardDistributorTest._shampoo_optim_factory(
281+
distributed_config=hybrid_shard_config,
282+
),
283+
ShampooHybridShardDistributorTest._model_factory(
284+
hybrid_shard_config,
285+
mesh_2d,
286+
),
287+
device=torch.device("cuda"),
288+
)
289+
290+
@with_comms
291+
@skip_if_lt_x_gpu(4)
292+
def test_hybrid_shard_shampoo_config_against_fully_shard_shampoo_config(
293+
self,
294+
) -> None:
295+
mesh_2d = init_device_mesh(
296+
"cuda", (2, 2), mesh_dim_names=("replicate", "shard")
297+
)
298+
for num_trainers_per_group, (
299+
communication_dtype,
300+
communicate_params,
301+
) in product(
302+
(-1, 1, 2),
303+
(
304+
(CommunicationDType.DEFAULT, False),
305+
(CommunicationDType.DEFAULT, True),
306+
(CommunicationDType.FP16, False),
307+
(CommunicationDType.BF16, False),
308+
),
309+
):
310+
hybrid_shard_config = HybridShardShampooConfig(
311+
device_mesh=mesh_2d,
312+
communication_dtype=communication_dtype,
313+
num_trainers_per_group=num_trainers_per_group,
314+
communicate_params=communicate_params,
315+
)
316+
317+
fully_shard_config = FullyShardShampooConfig()
318+
319+
with self.subTest(
320+
communication_dtype=communication_dtype,
321+
num_trainers_per_group=num_trainers_per_group,
322+
communicate_params=communicate_params,
323+
):
324+
ShampooHybridShardDistributorTest._test_two_configs(
325+
ShampooHybridShardDistributorTest._shampoo_optim_factory(
326+
distributed_config=fully_shard_config,
327+
),
328+
ShampooHybridShardDistributorTest._model_factory(
329+
fully_shard_config,
330+
mesh_2d,
269331
),
270332
ShampooHybridShardDistributorTest._shampoo_optim_factory(
271333
distributed_config=hybrid_shard_config,
272334
),
273335
ShampooHybridShardDistributorTest._model_factory(
274336
hybrid_shard_config,
337+
mesh_2d,
275338
),
276339
device=torch.device("cuda"),
277340
)
@@ -286,7 +349,8 @@ def test_hybrid_shard_shampoo_block_index(self) -> None:
286349
device_mesh=mesh_2d,
287350
)
288351
model_factory = ShampooHybridShardDistributorTest._model_factory(
289-
hybrid_shard_config
352+
hybrid_shard_config,
353+
device_mesh=mesh_2d,
290354
)
291355
optim_factory = ShampooHybridShardDistributorTest._shampoo_optim_factory(
292356
hybrid_shard_config
@@ -336,7 +400,8 @@ def test_number_of_trainers_per_group_out_of_range(self) -> None:
336400
distributed_config=hybrid_shard_config,
337401
),
338402
model_factory=ShampooHybridShardDistributorTest._model_factory(
339-
hybrid_shard_config
403+
hybrid_shard_config,
404+
device_mesh=mesh_2d,
340405
),
341406
device=torch.device("cuda"),
342407
)
@@ -362,7 +427,8 @@ def test_dist_is_initialized(self) -> None:
362427
distributed_config=hybrid_shard_config,
363428
),
364429
model_factory=ShampooHybridShardDistributorTest._model_factory(
365-
hybrid_shard_config
430+
hybrid_shard_config,
431+
device_mesh=mesh_2d,
366432
),
367433
device=torch.device("cuda"),
368434
)
@@ -394,7 +460,8 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
394460
distributed_config=hybrid_shard_config,
395461
),
396462
model_factory=ShampooHybridShardDistributorTest._model_factory(
397-
hybrid_shard_config
463+
hybrid_shard_config,
464+
device_mesh=mesh_2d,
398465
),
399466
device=torch.device("cuda"),
400467
)

distributed_shampoo/utils/shampoo_preconditioner_list.py

+3
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,9 @@ def _precondition_grad(
936936
) -> Tensor:
937937
# TODO: Need to refactor this function to be more efficient. Ideally eliminate those branches.
938938
# Might consider einsum?
939+
assert (
940+
sum(preconditioned_dims_selector) == len(preconditioner_list)
941+
), f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})."
939942
preconditioner_list_iter = iter(preconditioner_list)
940943
return reduce(
941944
lambda grad, should_precondition: torch.tensordot(

0 commit comments

Comments
 (0)