Conversation
Force-pushed from 3f34441 to 239ed36.
```python
# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])
if self.mixed_precision and self.fp32_reduce_scatter:
```
Currently for fp8 we do not use mixed_precision, so we should drop that part of the condition and only check `if self.fp32_reduce_scatter:`.
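A minimal toy illustrating the suggestion (this is not fairscale's actual hook; `should_cast` is a hypothetical helper): with fp8, `mixed_precision` is False, so gating the FP32 cast on `mixed_precision and fp32_reduce_scatter` would skip it, while gating on `fp32_reduce_scatter` alone performs it.

```python
# Hypothetical helper, for illustration only: decides whether the
# post-backward hook should cast the grad to FP32 before reduce-scatter.
def should_cast(mixed_precision, fp32_reduce_scatter, old_check=False):
    if old_check:
        # Original condition: requires mixed_precision too.
        return mixed_precision and fp32_reduce_scatter
    # Suggested condition: fp32_reduce_scatter alone.
    return fp32_reduce_scatter

# fp8 configuration: mixed_precision=False, but we still want an FP32 reduce-scatter.
assert should_cast(False, True, old_check=True) is False  # old check skips the cast
assert should_cast(False, True) is True                   # suggested check performs it
```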
Force-pushed from 239ed36 to ad54660.
```python
# Cast grad to FP32.
param.grad.data = param.grad.data.float()

orig_grad_data = param.grad.data
```
Moved here so that orig_grad_data is FP32. This follows up on #1139 (comment).
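A small runnable sketch of why the capture has to come after the cast (numpy stands in for torch here; `grad` plays the role of `param.grad.data`): the cast allocates a new buffer, so a reference taken beforehand still points at the low-precision tensor.

```python
import numpy as np

# Gradient as produced by backward, in half precision.
grad = np.ones(4, dtype=np.float16)

# Capture BEFORE the cast: the reference still points at the FP16
# buffer, because the cast below allocates a brand-new array.
before = grad
grad = grad.astype(np.float32)
assert before.dtype == np.float16        # stale FP16 reference

# Capture AFTER the cast: this is the FP32 buffer that the
# reduce-scatter would actually consume.
orig_grad_data = grad
assert orig_grad_data.dtype == np.float32
```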
```python
if self.fp32_reduce_scatter:
    # Cast grad to FP32.
    param.grad.data = param.grad.data.float()
```
I don't think this is right, since param.grad will be None from L1722.
Overall, this PR creates main_grad for flat parameters, but what we need is a main_grad that is visible to the TE modules. So we probably need to change FlatParameter as well?
Is this based on one of Naman's branches?
I have a branch where I am adding param.main_grad to FlatParams to enable fused wgrad accumulation; here is the PR: #1142.
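A hedged sketch of the idea discussed here (names like `FlatParamSketch` and `grad_views` are illustrative, not the actual FlatParameter API; numpy stands in for torch): keep one flat FP32 `main_grad` buffer and hand out per-parameter views of it, so modules such as TE layers can accumulate wgrads in FP32 even when the flat parameter storage is lower precision.

```python
import numpy as np

class FlatParamSketch:
    """Illustrative only: a flat param buffer plus a shared FP32 main_grad."""
    def __init__(self, shapes):
        sizes = [int(np.prod(s)) for s in shapes]
        # Flat low-precision parameter storage.
        self.flat = np.zeros(sum(sizes), dtype=np.float16)
        # One flat FP32 gradient accumulator for all parameters.
        self.main_grad = np.zeros(sum(sizes), dtype=np.float32)
        # Per-parameter views into main_grad: writes through a view
        # land directly in the flat FP32 buffer.
        self.grad_views = []
        offset = 0
        for size, shape in zip(sizes, shapes):
            self.grad_views.append(self.main_grad[offset:offset + size].reshape(shape))
            offset += size

fp = FlatParamSketch([(2, 2), (3,)])
fp.grad_views[0] += 1.0                 # a module accumulates into its own view
assert fp.main_grad[:4].sum() == 4.0    # accumulation is visible in the flat buffer
assert fp.main_grad[4:].sum() == 0.0    # other params' slices are untouched
```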
Thanks! Feel free to ignore the changes in this PR. Still learning about FlatParams etc.
What does this PR do?
Fixes main_grad, following up on #1139 (comment).
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.