Description
Is your feature request related to a problem? Please describe.
In some learning problems, the correct allreduce of gradients across data-parallel workers is SUM rather than MEAN. For example, when doing contrastive learning for embedding models, it is common to allgather embeddings across all workers during the forward pass, compute a global loss, and then back-propagate, making the correct gradient reduction a SUM aggregation. Both PyTorch DDP and DeepSpeed currently appear to hard-code a `1/world_size` rescaling of gradients into the distributed backwards pass before executing a SUM allreduce, essentially hard-coding the backwards-pass allreduce to be a MEAN operation.
To work around this hard-coded behavior, practitioners have to manually rescale the loss value by `world_size` before back-propagating (e.g. as this PyTorch community post suggests). Other workarounds include disabling gradient clipping, or adjusting the clipping threshold in proportion to the number of workers while relying on a scale-invariant optimizer like Adam.
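For concreteness, here is a minimal sketch of the loss-rescaling workaround. The helper name and the `engine`/`loss` arguments are placeholders of my own; `engine.backward`/`engine.step` are the usual DeepSpeed engine calls.

```python
import torch.distributed as dist

def backward_with_sum_reduction(engine, loss):
    # Hypothetical helper, not part of DeepSpeed: pre-scale the loss by
    # world_size so the engine's built-in 1/world_size gradient averaging
    # cancels out, leaving an effective SUM allreduce of the gradients.
    world_size = dist.get_world_size()
    engine.backward(loss * world_size)
    engine.step()
```

The same trick works with plain DDP by calling `(loss * world_size).backward()`, but it interacts with anything that depends on gradient magnitudes, such as the clipping threshold mentioned above.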
Describe the solution you'd like
I would like DeepSpeed to support an additional configuration option that simply disables the `1/world_size` scaling, e.g. my config could look like `{"zero_optimization": {"stage": 1}, "gradient_allreduce_op": "sum"}`.
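To illustrate, a sketch of how the proposed option might look in a training script. Note that `gradient_allreduce_op` is the proposed key and does not exist in DeepSpeed today, and `my_model` is a placeholder module defined elsewhere.

```python
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "zero_optimization": {"stage": 1},
    # Proposed option (not yet in DeepSpeed): skip the 1/world_size
    # rescaling and leave the gradient allreduce as a plain SUM.
    "gradient_allreduce_op": "sum",
}

# `my_model` is a placeholder torch.nn.Module; the dict is passed through
# the existing `config` keyword argument of deepspeed.initialize.
engine, optimizer, _, _ = deepspeed.initialize(
    model=my_model,
    model_parameters=my_model.parameters(),
    config=ds_config,
)
```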
Describe alternatives you've considered
I have thought about trying to subclass `DeepSpeedZeroOptimizer` and modify its `average_tensor` method so that it performs a SUM allreduce instead of averaging (I believe that would just involve gating line 1122 on a check that the config value is the default "MEAN" rather than "SUM"). However, this would only apply to ZeRO stage 1 and stage 2, and I couldn't figure out a clean way to actually use my own subclass in my training loop.
Additional context
I chatted a bit about this possible feature with @stas00, who mentioned it might be of interest to @tjruwase.