 from distributed_shampoo.shampoo_types import (
     AdaGradGraftingConfig,
     CommunicationDType,
+    FullyShardShampooConfig,
     HybridShardShampooConfig,
 )
 from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem
 from torch import nn
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor import DTensor
 from torch.optim.optimizer import ParamsT
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -52,27 +53,34 @@ def backend(self) -> str:
     @staticmethod
     def _construct_model(
         device: torch.device,
-        distributed_config: HybridShardShampooConfig | None,
+        distributed_config: FullyShardShampooConfig | HybridShardShampooConfig | None,
+        device_mesh: DeviceMesh | None,
     ) -> tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor, bool]:
         IN_DIM = 16
         data = torch.arange(IN_DIM, dtype=torch.float, device=device)
         data /= torch.norm(data)
-        # NOTE: We construct the model here specifically in order to ensure that
-        # FullyShard Shampoo and default Shampoo produce equivalent results.
+        # NOTE: We construct the model here specifically in order to ensure that HybridShard
+        # Shampoo and default Shampoo produce equivalent results.
         # This requires us to construct a model such that FullyShard will split the
-        # parameters such that the preconditioners created between the FullyShard
+        # parameters such that the preconditioners created between the HybridShard Shampoo
         # and default Shampoo are equivalent.
-        # +----------------+
-        # |    [4, 16]     |
-        # |      GPU0      |
-        # --------------------    +------+
-        # |    [4, 16]     |      |[4, 4]|
-        # |      GPU1      |      |      |
-        # +----------------+      +------+
+        #
+        # In a (2, 2) mesh, we have the following parameter distribution:
+        #
+        # +----------------+                  +----------------+
+        # |    [4, 16]     |                  |    [4, 16]     |
+        # |      GPU0      |                  |      GPU1      |
+        # --------------------    +------+    --------------------    +------+
+        # |    [4, 16]     |      |[4, 4]|    |    [4, 16]     |      |[4, 4]|
+        # |      GPU2      |      |      |    |      GPU3      |      |      |
+        # +----------------+      +------+    +----------------+      +------+
+        #
+        # Each FSDP group has the complete model. (GPU0, GPU2) and (GPU1, GPU3) are
+        # 2 FSDP groups.
+        #
         # For the first linear layer, each GPU has a [4, 16] parameter. The blocked
-        # parameters are of size [4, 4] and each GPU has four local blocks (eight
-        # blocks in total). In comparison, with default shampoo, the eight blocks
-        # are replicated on two GPUs.
+        # parameters are of size [4, 4] and each GPU has four local blocks. In comparison,
+        # with default shampoo, the eight blocks are replicated on four GPUs.
         # Similarly, the second linear layer has a [1, 8] parameter and is split
         # into two [4] chunks.

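A minimal standalone sketch of the blocking arithmetic described in the comment above (not part of the diff; the helper name is illustrative, the shapes follow the test):

    import torch

    def count_local_blocks(local_shard: torch.Tensor, block_size: int = 4) -> int:
        # Split the [4, 16] per-GPU FSDP shard into [4, 4] blocks along the last dim.
        return len(local_shard.split(block_size, dim=-1))

    local_shard = torch.zeros(4, 16)             # per-GPU shard of the first linear layer
    assert count_local_blocks(local_shard) == 4  # four local blocks per GPU
    # Two shards per FSDP group ([8, 16] in total) -> eight [4, 4] blocks overall,
    # matching the eight blocks that default Shampoo replicates on all four GPUs.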
@@ -86,16 +94,15 @@ def _construct_model(
             fill=0.1,
         )

-        if uses_hybrid_shard := isinstance(
-            distributed_config, HybridShardShampooConfig
+        if use_fsdp2 := (
+            isinstance(
+                distributed_config, (HybridShardShampooConfig, FullyShardShampooConfig)
+            )
         ):
             # Need this to pass the type-checking test.
             assert distributed_config is not None
-            model = fully_shard(
-                model,
-                mesh=distributed_config.device_mesh,
-            )
-        return model, loss, data, target, uses_hybrid_shard
+            model = fully_shard(model, mesh=device_mesh)
+        return model, loss, data, target, use_fsdp2

     @staticmethod
     def _train_model(
@@ -190,7 +197,7 @@ def _test_two_configs(

     @staticmethod
     def _shampoo_optim_factory(
-        distributed_config: HybridShardShampooConfig | None,
+        distributed_config: FullyShardShampooConfig | HybridShardShampooConfig | None,
     ) -> Callable[
         [ParamsT],
         torch.optim.Optimizer,
@@ -214,7 +221,8 @@ def _shampoo_optim_factory(

     @staticmethod
     def _model_factory(
-        distributed_config: HybridShardShampooConfig | None,
+        distributed_config: FullyShardShampooConfig | HybridShardShampooConfig | None,
+        device_mesh: DeviceMesh | None,
     ) -> Callable[
         [torch.device],
         tuple[
@@ -228,6 +236,7 @@ def _model_factory(
         return partial(
             ShampooHybridShardDistributorTest._construct_model,
             distributed_config=distributed_config,
+            device_mesh=device_mesh,
         )

     @with_comms
@@ -266,12 +275,66 @@ def test_hybrid_shard_shampoo_against_default_shampoo(self) -> None:
                     ),
                     ShampooHybridShardDistributorTest._model_factory(
                         None,
+                        None,
+                    ),
+                    ShampooHybridShardDistributorTest._shampoo_optim_factory(
+                        distributed_config=hybrid_shard_config,
+                    ),
+                    ShampooHybridShardDistributorTest._model_factory(
+                        hybrid_shard_config,
+                        mesh_2d,
+                    ),
+                    device=torch.device("cuda"),
+                )
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_hybrid_shard_shampoo_config_against_fully_shard_shampoo_config(
+        self,
+    ) -> None:
+        mesh_2d = init_device_mesh(
+            "cuda", (2, 2), mesh_dim_names=("replicate", "shard")
+        )
+        for num_trainers_per_group, (
+            communication_dtype,
+            communicate_params,
+        ) in product(
+            (-1, 1, 2),
+            (
+                (CommunicationDType.DEFAULT, False),
+                (CommunicationDType.DEFAULT, True),
+                (CommunicationDType.FP16, False),
+                (CommunicationDType.BF16, False),
+            ),
+        ):
+            hybrid_shard_config = HybridShardShampooConfig(
+                device_mesh=mesh_2d,
+                communication_dtype=communication_dtype,
+                num_trainers_per_group=num_trainers_per_group,
+                communicate_params=communicate_params,
+            )
+
+            fully_shard_config = FullyShardShampooConfig()
+
+            with self.subTest(
+                communication_dtype=communication_dtype,
+                num_trainers_per_group=num_trainers_per_group,
+                communicate_params=communicate_params,
+            ):
+                ShampooHybridShardDistributorTest._test_two_configs(
+                    ShampooHybridShardDistributorTest._shampoo_optim_factory(
+                        distributed_config=fully_shard_config,
+                    ),
+                    ShampooHybridShardDistributorTest._model_factory(
+                        fully_shard_config,
+                        mesh_2d,
                     ),
                     ShampooHybridShardDistributorTest._shampoo_optim_factory(
                         distributed_config=hybrid_shard_config,
                     ),
                     ShampooHybridShardDistributorTest._model_factory(
                         hybrid_shard_config,
+                        mesh_2d,
                     ),
                     device=torch.device("cuda"),
                 )
@@ -286,7 +349,8 @@ def test_hybrid_shard_shampoo_block_index(self) -> None:
             device_mesh=mesh_2d,
         )
         model_factory = ShampooHybridShardDistributorTest._model_factory(
-            hybrid_shard_config
+            hybrid_shard_config,
+            device_mesh=mesh_2d,
         )
         optim_factory = ShampooHybridShardDistributorTest._shampoo_optim_factory(
             hybrid_shard_config
@@ -336,7 +400,8 @@ def test_number_of_trainers_per_group_out_of_range(self) -> None:
                 distributed_config=hybrid_shard_config,
             ),
             model_factory=ShampooHybridShardDistributorTest._model_factory(
-                hybrid_shard_config
+                hybrid_shard_config,
+                device_mesh=mesh_2d,
             ),
             device=torch.device("cuda"),
         )
@@ -362,7 +427,8 @@ def test_dist_is_initialized(self) -> None:
                 distributed_config=hybrid_shard_config,
             ),
             model_factory=ShampooHybridShardDistributorTest._model_factory(
-                hybrid_shard_config
+                hybrid_shard_config,
+                device_mesh=mesh_2d,
            ),
             device=torch.device("cuda"),
         )
@@ -394,7 +460,8 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
                 distributed_config=hybrid_shard_config,
             ),
             model_factory=ShampooHybridShardDistributorTest._model_factory(
-                hybrid_shard_config
+                hybrid_shard_config,
+                device_mesh=mesh_2d,
             ),
             device=torch.device("cuda"),
         )
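For reference outside the test harness, the setup the new test exercises looks roughly like the sketch below. It assumes a single host with 4 GPUs and an already-initialized process group, plus a toy two-layer model with the weight shapes from the comments ([8, 16] and [1, 8], bias handling assumed); the real model comes from construct_training_problem and the optimizer wiring from _shampoo_optim_factory, neither of which appears in this diff.

    import torch
    from torch import nn
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed.device_mesh import init_device_mesh

    from distributed_shampoo.shampoo_types import (
        CommunicationDType,
        FullyShardShampooConfig,
        HybridShardShampooConfig,
    )

    # 2-way replicate x 2-way shard mesh, matching the tests above.
    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("replicate", "shard"))

    # The two distributed configs being compared; HybridShard carries the mesh and
    # communication options, while FullyShard is constructed with no arguments here.
    hybrid_shard_config = HybridShardShampooConfig(
        device_mesh=mesh_2d,
        communication_dtype=CommunicationDType.DEFAULT,
        num_trainers_per_group=-1,
        communicate_params=False,
    )
    fully_shard_config = FullyShardShampooConfig()

    # FSDP2 wrapping mirrors _construct_model after this change: the mesh is passed
    # explicitly instead of being read off the distributed config.
    model = nn.Sequential(
        nn.Linear(16, 8, bias=False),
        nn.Linear(8, 1, bias=False),
    ).to("cuda")
    model = fully_shard(model, mesh=mesh_2d)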