Skip to content

Commit 8e25111

Browse files
tsunghsienlee
authored and facebook-github-bot committed
Open-sourced update on 03/25/2024 (facebookresearch#97)
Summary: Pull Request resolved: facebookresearch#97 1. Add `ShampooPreconditionerConfig.inverse_exponent_override` to merge `inv_root_override`, `ShampooPreconditionerConfig.ignored_dims`, and `ShampooPreconditionerConfig.amortized_computation_config.exponent_multiplier` in Shampoo. 2. Add `EigenvalueCorrectedShampooPreconditionerConfig.inverse_exponent_override` to replace `inv_root_override`, and push `PreconditionerConfig.ignored_dims` to `EigenvalueCorrectedShampooPreconditionerConfig.ignored_basis_change_dims` in eigenvalue-corrected Shampoo. Reviewed By: anana10c Differential Revision: D71823575 fbshipit-source-id: 14743fdddebb78a89005af01a61112fdbca9a3ef
1 parent 9ed1df9 commit 8e25111

14 files changed

+525
-271
lines changed

distributed_shampoo/distributed_shampoo.py

+11-24
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
import logging
11-
from collections.abc import Callable, Iterator, Sequence
11+
from collections.abc import Callable, Iterator
1212
from copy import deepcopy
1313
from functools import partial
1414
from typing import Any
@@ -36,7 +36,6 @@
3636
GraftingConfig,
3737
HSDPShampooConfig,
3838
HybridShardShampooConfig,
39-
INV_ROOT_OVERRIDE,
4039
LR,
4140
MASKED_BLOCKED_GRADS,
4241
MASKED_BLOCKED_PARAMS,
@@ -265,12 +264,6 @@ class DistributedShampoo(torch.optim.Optimizer):
265264
(Default: 1)
266265
start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses
267266
the same value as precondition_frequency. (Default: -1)
268-
inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
269-
use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
270-
tensor exceeds the order of the tensor, reverts to the default value. If 0 is used, uses the default inverse
271-
root -1 / (2 * o), where o is the order of the tensor. If preconditioner_config is an instance of
272-
EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2.
273-
(Default: 0)
274267
use_nesterov (bool): Flag for using Nesterov momentum. (default: False)
275268
use_bias_correction (bool): Flag for using bias correction. (Default: True)
276269
use_decoupled_weight_decay (bool): Flag for using AdamW-style decoupled weight decay. (Default: True)
@@ -303,7 +296,6 @@ def __init__(
303296
max_preconditioner_dim: int = 1024,
304297
precondition_frequency: int = 1,
305298
start_preconditioning_step: int = -1,
306-
inv_root_override: int | Sequence[int] = 0,
307299
use_nesterov: bool = False,
308300
use_bias_correction: bool = True,
309301
use_decoupled_weight_decay: bool = True,
@@ -357,16 +349,6 @@ def __init__(
357349
raise ValueError(
358350
f"Invalid start preconditioning step: {start_preconditioning_step}. Must be >= -1."
359351
)
360-
if isinstance(inv_root_override, Sequence):
361-
if not all(e >= 0 for e in inv_root_override):
362-
raise ValueError(
363-
f"Invalid exponent override list: {inv_root_override}. All values must be >= 0."
364-
)
365-
else:
366-
if not inv_root_override >= 0:
367-
raise ValueError(
368-
f"Invalid exponent override: {inv_root_override}. Must be >= 0."
369-
)
370352

371353
# Provide warning/error for start_preconditioning_step.
372354
if start_preconditioning_step == -1:
@@ -387,10 +369,17 @@ def __init__(
387369
"Continuing without using momentum or Nesterov acceleration..."
388370
)
389371

390-
# Check potential conflict between preconditioner_config.ignored_dims and inv_root_override.
391-
if preconditioner_config.ignored_dims != [] and inv_root_override != 0:
372+
# No use of preconditioner_config.amortized_computation_config.exponent_multiplier.
373+
if (
374+
getattr(
375+
preconditioner_config.amortized_computation_config,
376+
"exponent_multiplier",
377+
1.0,
378+
)
379+
!= 1.0
380+
):
392381
raise ValueError(
393-
f"{preconditioner_config.ignored_dims=} is not supported when {inv_root_override=} is not set to 0. Please set {inv_root_override=} to 0 if you set {preconditioner_config.ignored_dims=}."
382+
"preconditioner_config.amortized_computation_config.exponent_multiplier is not supported. Please use PreconditionerConfig.inverse_exponent_override instead."
394383
)
395384

396385
super().__init__(
@@ -406,7 +395,6 @@ def __init__(
406395
MAX_PRECONDITIONER_DIM: max_preconditioner_dim,
407396
PRECONDITION_FREQUENCY: precondition_frequency,
408397
START_PRECONDITIONING_STEP: start_preconditioning_step,
409-
INV_ROOT_OVERRIDE: inv_root_override,
410398
USE_NESTEROV: use_nesterov,
411399
USE_BIAS_CORRECTION: use_bias_correction,
412400
USE_DECOUPLED_WEIGHT_DECAY: use_decoupled_weight_decay,
@@ -516,7 +504,6 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
516504
preconditioner_config=group[PRECONDITIONER_CONFIG],
517505
beta2=group[BETAS][1],
518506
epsilon=group[EPSILON],
519-
inv_root_override=group[INV_ROOT_OVERRIDE],
520507
use_bias_correction=group[USE_BIAS_CORRECTION],
521508
factor_matrix_dtype=group[PRECONDITIONER_DTYPE],
522509
)

distributed_shampoo/examples/ddp_cifar10_example.py

-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@
108108
max_preconditioner_dim=args.max_preconditioner_dim,
109109
precondition_frequency=args.precondition_frequency,
110110
start_preconditioning_step=args.start_preconditioning_step,
111-
inv_root_override=args.inv_root_override,
112111
exponent_multiplier=args.exponent_multiplier,
113112
use_nesterov=args.use_nesterov,
114113
use_bias_correction=args.use_bias_correction,

distributed_shampoo/examples/default_cifar10_example.py

-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def train_default_model(
122122
max_preconditioner_dim=args.max_preconditioner_dim,
123123
precondition_frequency=args.precondition_frequency,
124124
start_preconditioning_step=args.start_preconditioning_step,
125-
inv_root_override=args.inv_root_override,
126125
exponent_multiplier=args.exponent_multiplier,
127126
use_nesterov=args.use_nesterov,
128127
use_bias_correction=args.use_bias_correction,

distributed_shampoo/examples/fsdp_cifar10_example.py

-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@
102102
max_preconditioner_dim=args.max_preconditioner_dim,
103103
precondition_frequency=args.precondition_frequency,
104104
start_preconditioning_step=args.start_preconditioning_step,
105-
inv_root_override=args.inv_root_override,
106105
exponent_multiplier=args.exponent_multiplier,
107106
use_nesterov=args.use_nesterov,
108107
use_bias_correction=args.use_bias_correction,

distributed_shampoo/examples/fully_shard_cifar10_example.py

-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def create_model_and_optimizer_and_loss_fn(args, device):
124124
max_preconditioner_dim=args.max_preconditioner_dim,
125125
precondition_frequency=args.precondition_frequency,
126126
start_preconditioning_step=args.start_preconditioning_step,
127-
inv_root_override=args.inv_root_override,
128127
exponent_multiplier=args.exponent_multiplier,
129128
use_nesterov=args.use_nesterov,
130129
use_bias_correction=args.use_bias_correction,

distributed_shampoo/examples/hsdp_cifar10_example.py

-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@
115115
max_preconditioner_dim=args.max_preconditioner_dim,
116116
precondition_frequency=args.precondition_frequency,
117117
start_preconditioning_step=args.start_preconditioning_step,
118-
inv_root_override=args.inv_root_override,
119118
exponent_multiplier=args.exponent_multiplier,
120119
use_nesterov=args.use_nesterov,
121120
use_bias_correction=args.use_bias_correction,

distributed_shampoo/examples/trainer_utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,6 @@ def instantiate_optimizer(
371371
max_preconditioner_dim: int,
372372
precondition_frequency: int,
373373
start_preconditioning_step: int,
374-
inv_root_override: int,
375374
exponent_multiplier: float,
376375
use_nesterov: bool,
377376
use_bias_correction: bool,
@@ -423,7 +422,6 @@ def instantiate_optimizer(
423422
max_preconditioner_dim=max_preconditioner_dim,
424423
precondition_frequency=precondition_frequency,
425424
start_preconditioning_step=start_preconditioning_step,
426-
inv_root_override=inv_root_override,
427425
use_nesterov=use_nesterov,
428426
use_bias_correction=use_bias_correction,
429427
use_decoupled_weight_decay=use_decoupled_weight_decay,

distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
6363
(math.inf, DefaultSOAPConfig),
6464
(
6565
1,
66-
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
66+
EigenvalueCorrectedShampooPreconditionerConfig(
67+
ignored_basis_change_dims={0: [0], 1: [0], 2: [0, 1]}
68+
),
6769
),
6870
),
6971
):
@@ -113,7 +115,9 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
113115
(math.inf, DefaultSOAPConfig),
114116
(
115117
1,
116-
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
118+
EigenvalueCorrectedShampooPreconditionerConfig(
119+
ignored_basis_change_dims={0: [0], 1: [0], 2: [0, 1]}
120+
),
117121
),
118122
),
119123
):
@@ -165,7 +169,9 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
165169
(math.inf, DefaultSOAPConfig),
166170
(
167171
1,
168-
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
172+
EigenvalueCorrectedShampooPreconditionerConfig(
173+
ignored_basis_change_dims={0: [0], 1: [0], 2: [0, 1]}
174+
),
169175
),
170176
),
171177
):
@@ -217,7 +223,9 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
217223
(math.inf, DefaultSOAPConfig),
218224
(
219225
1,
220-
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
226+
EigenvalueCorrectedShampooPreconditionerConfig(
227+
ignored_basis_change_dims={0: [0], 1: [0], 2: [0, 1]}
228+
),
221229
),
222230
),
223231
):

0 commit comments

Comments (0)