Your question
According to my understanding, there are two facts:
- the grad op of a communication op is still a communication op (e.g. allreduce's grad is still allreduce; allgather's grad is reduce-scatter)
- leaving aside the gradient allreduce of data parallelism, if no communication op exists in the forward pass, then none exists in the backward pass either.
So, why not use a single function g, whose forward and backward operations are both allreduce, instead of the two conjugate functions?
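For concreteness, here is a minimal single-process sketch of the two conjugate functions I mean (named f and g as in the Megatron-LM paper). The per-rank values are plain Python lists and `all_reduce` is a summed broadcast; both are simplifying assumptions for illustration, not a real distributed implementation:

```python
def all_reduce(per_rank_values):
    """Simulate allreduce: sum across ranks, broadcast the sum to every rank."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]

def f_forward(per_rank_values):
    # f: identity in the forward pass ...
    return list(per_rank_values)

def f_backward(per_rank_grads):
    # ... allreduce in the backward pass
    return all_reduce(per_rank_grads)

def g_forward(per_rank_values):
    # g: allreduce in the forward pass ...
    return all_reduce(per_rank_values)

def g_backward(per_rank_grads):
    # ... identity in the backward pass
    return list(per_rank_grads)
```

My question amounts to: why does g need `g_backward` to be the identity, rather than being allreduce like `g_forward`?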