@@ -406,9 +406,8 @@ class TrainArgs:
406406 """
407407 Modified training-related arguments in :class:`litgpt.args.TrainArgs`.
408408
409- Here, `global_batch_size` does not have a default value. If not given,
410- it should be set to the product of `micro_batch_size` and the number of
411- devices, unless sequential gradient averaging is desired.
409+ `global_batch_size` is a legacy argument, which must be equal to the
410+ product of `micro_batch_size` and the number of devices, if given.
412411
413412 Storing intermediate checkpoints: Normal checkpoints are stored whenever
414413 `state["step_count"] % train.save_interval == 0`. If
@@ -437,7 +436,7 @@ class TrainArgs:
437436 log_interval : int = 1
438437 """Number of iterations between logging calls"""
439438 global_batch_size : Optional [int ] = None
440- """Number of samples between optimizer steps across data-parallel ranks """
439+ """Legacy argument: Do not use """
441440 micro_batch_size : int = 4
442441 """Number of samples per data-parallel rank"""
443442 lr_warmup_steps : Optional [int ] = 100
@@ -506,23 +505,6 @@ def __post_init__(self) -> None:
506505 if self .max_grad_norm is not None and self .max_grad_norm <= 0 :
507506 raise ValueError ("max_grad_norm must be positive (or `None` to disable)" )
508507
509- def gradient_accumulation_iters (self , devices : int , num_nodes : int = 1 ) -> int :
510- """Number of iterations between gradient synchronizations"""
511- gradient_accumulation_iters = (
512- self .batch_size (devices , num_nodes ) // self .micro_batch_size
513- )
514- assert gradient_accumulation_iters > 0
515- return gradient_accumulation_iters
516-
517- def batch_size (self , devices : int , num_nodes : int = 1 ) -> int :
518- """Number of samples between optimizer steps per data-parallel rank"""
519- if self .global_batch_size is None :
520- batch_size = self .micro_batch_size
521- else :
522- batch_size = self .global_batch_size // (devices * num_nodes )
523- assert batch_size > 0
524- return batch_size
525-
526508 def warmup_iters (
527509 self , devices : int , num_nodes : int , max_iters : int , train_dataloader
528510 ) -> int :
@@ -532,11 +514,7 @@ def warmup_iters(
532514 max_iters , math .ceil (self .lr_warmup_fraction * len (train_dataloader ))
533515 )
534516 if self .lr_warmup_steps :
535- return min (
536- max_iters ,
537- self .lr_warmup_steps
538- * self .gradient_accumulation_iters (devices , num_nodes ),
539- )
517+ return min (max_iters , self .lr_warmup_steps )
540518 return 0
541519
542520
0 commit comments