Introduce a mark_sharding that also shards the backward #8678

Open
@tengyifei

Description

🚀 Feature

PyTorch/XLA's xs.mark_sharding is an in-place operation that adds a sharding annotation to an XLA tensor. However, the gradient computed for that tensor in the backward pass does not receive the same annotation.
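
For illustration, a minimal repro of the current behavior might look like the following; the mesh shape, axis names, and tensor sizes here are made up:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

# Hypothetical 2D mesh over all available devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(8, 4, requires_grad=True, device=xm.xla_device())
xs.mark_sharding(t, mesh, ('data', 'model'))  # annotates t in place

loss = (t * t).sum()
loss.backward()
# t.grad carries no sharding annotation unless GSPMD happens to
# propagate one during compilation.
```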

Motivation

In some cases, GSPMD fails to propagate the sharding annotation from a tensor to its gradient. It is useful to shard both the tensor and its gradient with the same annotation.

Pitch

We could implement this with a custom torch.autograd.Function that applies the annotation in the forward pass and re-applies it to the incoming gradient in the backward pass; a sketch follows below.
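
A minimal sketch of such a wrapper, assuming the same mesh and partition_spec arguments that xs.mark_sharding takes; the names _MarkShardingFunction and mark_sharding_with_gradients are hypothetical:

```python
import torch
import torch_xla.distributed.spmd as xs


class _MarkShardingFunction(torch.autograd.Function):
    """Applies a sharding annotation in the forward pass and re-applies
    the same annotation to the incoming gradient in the backward pass."""

    @staticmethod
    def forward(ctx, tensor, mesh, partition_spec):
        ctx.mesh = mesh
        ctx.partition_spec = partition_spec
        # Clone so the in-place annotation attaches to a tensor owned by
        # this autograd node rather than mutating the caller's input.
        sharded = tensor.clone()
        xs.mark_sharding(sharded, mesh, partition_spec)
        return sharded

    @staticmethod
    def backward(ctx, grad_output):
        # Annotate the gradient with the same sharding spec.
        grad = grad_output.clone()
        xs.mark_sharding(grad, ctx.mesh, ctx.partition_spec)
        # No gradients for the mesh/partition_spec arguments.
        return grad, None, None


def mark_sharding_with_gradients(tensor, mesh, partition_spec):
    return _MarkShardingFunction.apply(tensor, mesh, partition_spec)
```

Usage would mirror xs.mark_sharding, e.g. `t = mark_sharding_with_gradients(t, mesh, ('data', 'model'))`, except that the result is a new tensor in the autograd graph rather than an in-place update.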

Additional context

JAX's equivalent of mark_sharding shards the gradients too.

cc @bhavya01
