8
8
"""
9
9
10
10
import logging
11
- from collections .abc import Callable , Iterator , Sequence
11
+ from collections .abc import Callable , Iterator
12
12
from copy import deepcopy
13
13
from functools import partial
14
14
from typing import Any
36
36
GraftingConfig ,
37
37
HSDPShampooConfig ,
38
38
HybridShardShampooConfig ,
39
- INV_ROOT_OVERRIDE ,
40
39
LR ,
41
40
MASKED_BLOCKED_GRADS ,
42
41
MASKED_BLOCKED_PARAMS ,
@@ -265,12 +264,6 @@ class DistributedShampoo(torch.optim.Optimizer):
265
264
(Default: 1)
266
265
start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses
267
266
the same value as precondition_frequency. (Default: -1)
268
- inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
269
- use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
270
- tensor exceeds the order of the tensor, reverts to the default value. If 0 is used, uses the default inverse
271
- root -1 / (2 * o), where o is the order of the tensor. If preconditioner_config is an instance of
272
- EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2.
273
- (Default: 0)
274
267
use_nesterov (bool): Flag for using Nesterov momentum. (default: False)
275
268
use_bias_correction (bool): Flag for using bias correction. (Default: True)
276
269
use_decoupled_weight_decay (bool): Flag for using AdamW-style decoupled weight decay. (Default: True)
@@ -303,7 +296,6 @@ def __init__(
303
296
max_preconditioner_dim : int = 1024 ,
304
297
precondition_frequency : int = 1 ,
305
298
start_preconditioning_step : int = - 1 ,
306
- inv_root_override : int | Sequence [int ] = 0 ,
307
299
use_nesterov : bool = False ,
308
300
use_bias_correction : bool = True ,
309
301
use_decoupled_weight_decay : bool = True ,
@@ -357,16 +349,6 @@ def __init__(
357
349
raise ValueError (
358
350
f"Invalid start preconditioning step: { start_preconditioning_step } . Must be >= -1."
359
351
)
360
- if isinstance (inv_root_override , Sequence ):
361
- if not all (e >= 0 for e in inv_root_override ):
362
- raise ValueError (
363
- f"Invalid exponent override list: { inv_root_override } . All values must be >= 0."
364
- )
365
- else :
366
- if not inv_root_override >= 0 :
367
- raise ValueError (
368
- f"Invalid exponent override: { inv_root_override } . Must be >= 0."
369
- )
370
352
371
353
# Provide warning/error for start_preconditioning_step.
372
354
if start_preconditioning_step == - 1 :
@@ -387,10 +369,17 @@ def __init__(
387
369
"Continuing without using momentum or Nesterov acceleration..."
388
370
)
389
371
390
- # Check potential conflict between preconditioner_config.ignored_dims and inv_root_override.
391
- if preconditioner_config .ignored_dims != [] and inv_root_override != 0 :
372
+ # Reject any non-default preconditioner_config.amortized_computation_config.exponent_multiplier (only 1.0 is accepted).
373
+ if (
374
+ getattr (
375
+ preconditioner_config .amortized_computation_config ,
376
+ "exponent_multiplier" ,
377
+ 1.0 ,
378
+ )
379
+ != 1.0
380
+ ):
392
381
raise ValueError (
393
- f" { preconditioner_config .ignored_dims = } is not supported when { inv_root_override = } is not set to 0 . Please set { inv_root_override = } to 0 if you set { preconditioner_config . ignored_dims = } ."
382
+ " preconditioner_config.amortized_computation_config.exponent_multiplier is not supported. Please use PreconditionerConfig.inverse_exponent_override instead ."
394
383
)
395
384
396
385
super ().__init__ (
@@ -406,7 +395,6 @@ def __init__(
406
395
MAX_PRECONDITIONER_DIM : max_preconditioner_dim ,
407
396
PRECONDITION_FREQUENCY : precondition_frequency ,
408
397
START_PRECONDITIONING_STEP : start_preconditioning_step ,
409
- INV_ROOT_OVERRIDE : inv_root_override ,
410
398
USE_NESTEROV : use_nesterov ,
411
399
USE_BIAS_CORRECTION : use_bias_correction ,
412
400
USE_DECOUPLED_WEIGHT_DECAY : use_decoupled_weight_decay ,
@@ -516,7 +504,6 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
516
504
preconditioner_config = group [PRECONDITIONER_CONFIG ],
517
505
beta2 = group [BETAS ][1 ],
518
506
epsilon = group [EPSILON ],
519
- inv_root_override = group [INV_ROOT_OVERRIDE ],
520
507
use_bias_correction = group [USE_BIAS_CORRECTION ],
521
508
factor_matrix_dtype = group [PRECONDITIONER_DTYPE ],
522
509
)
0 commit comments