Sometimes we want to ensure a fixed total_batch_size.
total_batch_size = local_batch_size * replica_world_size * accumulation_steps
Generally, local_batch_size is fixed, and we can adjust accumulation_steps based on replica_world_size to maintain a fixed total_batch_size.
However, not every replica_world_size yields an integer accumulation_steps. For example, if local_batch_size = 4, total_batch_size = 16, and replica_world_size = 3, then accumulation_steps would have to be 16 / (4 * 3) = 4/3, so no positive integer value satisfies the equation. We can avoid this problem by restricting replica_world_size to values that divide total_batch_size / local_batch_size evenly. In this case, if expected_replicas = 1, 2, 4, the corresponding accumulation_steps are 4, 2, 1, and total_batch_size stays at 16 in every configuration.
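
The arithmetic above can be checked with a short Python sketch. The function names accumulation_steps_for and valid_replica_world_sizes are hypothetical helpers introduced here for illustration only; they are not part of any existing API, and the sketch assumes all three quantities are positive integers.

```python
def accumulation_steps_for(total_batch_size: int,
                           local_batch_size: int,
                           replica_world_size: int) -> int:
    """Solve total = local * replicas * steps for steps (hypothetical helper)."""
    samples_per_step = local_batch_size * replica_world_size
    if total_batch_size % samples_per_step != 0:
        # No positive integer accumulation_steps satisfies the equation.
        raise ValueError(
            f"replica_world_size={replica_world_size} cannot reach "
            f"total_batch_size={total_batch_size} with "
            f"local_batch_size={local_batch_size}"
        )
    return total_batch_size // samples_per_step


def valid_replica_world_sizes(total_batch_size: int,
                              local_batch_size: int) -> list[int]:
    """List the replica counts that give an integer accumulation_steps."""
    if total_batch_size % local_batch_size != 0:
        return []
    quotient = total_batch_size // local_batch_size
    # Valid replica counts are exactly the divisors of the quotient.
    return [r for r in range(1, quotient + 1) if quotient % r == 0]


if __name__ == "__main__":
    for replicas in valid_replica_world_sizes(16, 4):
        steps = accumulation_steps_for(16, 4, replicas)
        print(f"replica_world_size={replicas} -> accumulation_steps={steps}")
```

With local_batch_size = 4 and total_batch_size = 16, the valid replica counts are the divisors of 16 / 4 = 4, which matches the expected_replicas values given above. Calling accumulation_steps_for(16, 4, 3) raises ValueError, which is the failure case that restricting the replica count avoids.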