🚀 Feature
PyTorch/XLA xs.mark_sharding is an in-place operation that adds a sharding annotation to an XLA tensor. However, the gradients that flow back to that tensor are not annotated with the same sharding.
Motivation
In some cases, GSPMD fails to propagate the sharding annotation from a tensor to its gradient. It would be useful to shard both the tensor and its gradient with the same annotation.
Pitch
We could write a torch.autograd.Function implementation that applies the same sharding annotation in the forward pass and to the incoming gradient in the backward pass, as sketched below.
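A minimal sketch of that idea; it assumes xs.mark_sharding returns an XLAShardedTensor whose global_tensor attribute exposes the annotated tensor, and the class name here is hypothetical:

```python
import torch
import torch_xla.distributed.spmd as xs


class MarkShardingFunction(torch.autograd.Function):
    """Hypothetical autograd.Function: annotates a tensor with a sharding
    spec in forward, and annotates its gradient with the same spec in
    backward, so GSPMD sees consistent annotations on both."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, mesh: xs.Mesh, partition_spec):
        # Save the mesh and partition spec for the backward pass.
        ctx.mesh = mesh
        ctx.partition_spec = partition_spec
        sharded = xs.mark_sharding(tensor, mesh, partition_spec)
        return sharded.global_tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Apply the identical sharding annotation to the incoming gradient.
        sharded = xs.mark_sharding(grad_output, ctx.mesh, ctx.partition_spec)
        # mesh and partition_spec are non-tensor inputs: no gradients.
        return sharded.global_tensor, None, None
```

Usage would then replace a plain xs.mark_sharding call, e.g. t = MarkShardingFunction.apply(t, mesh, partition_spec).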
Additional context
JAX mark_sharding shards the gradients too.
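For comparison, a small sketch of that JAX behavior, assuming the primitive referenced above corresponds to jax.lax.with_sharding_constraint, whose gradient rule applies the same constraint to the cotangent; the mesh setup is illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.lax import with_sharding_constraint

# Illustrative 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))


def f(t):
    # The constraint is differentiable: during backprop, the cotangent
    # of t is constrained to the same sharding.
    t = with_sharding_constraint(t, sharding)
    return jnp.sum(t * t)


grad = jax.jit(jax.grad(f))(jnp.ones((8,)))
```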
cc @bhavya01