
Introduce annotate_custom_sharding binding #9203

Open · wants to merge 1 commit into base: master
42 changes: 42 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1682,6 +1682,48 @@ def test_shard_as(self):
self.assertIn(sharding_spec, x_sharding)
self.assertEqual(x_sharding, y_sharding)

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed")
def test_annotate_custom_sharding(self):
xt = torch.randn(2, 4, 64, 64).to(xm.xla_device())
sharded_mesh_axis_0 = self.n_devices // 2
sharded_mesh_axis_1 = self.n_devices // sharded_mesh_axis_0

xs.mark_sharding(
xt, self._get_mesh((1, 1, sharded_mesh_axis_0, sharded_mesh_axis_1)),
(0, 1, 2, 3))
original_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)

# Attempting to reshard the original tensor should result in a failure
with self.assertRaises(RuntimeError):
xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))

self.assertEqual(original_sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(xt))

# Annotate the existing XLAShardedTensor with a custom sharding IR
xs.annotate_custom_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))

custom_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)

self.assertEqual(custom_sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(xt))
self.assertNotEqual(custom_sharding_spec, original_sharding_spec)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
self.assertIn(
f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
hlo)
self.assertIn(
f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
hlo)
xm.mark_step()
# Ensure that the resulting sharding spec is preserved
self.assertEqual(custom_sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(xt))


if __name__ == '__main__':
test = unittest.main()
5 changes: 5 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2240,6 +2240,11 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_xla_annotate_custom_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding);
});
m.def("_mark_manual_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLA_CHECK(IsNonDeviceDataIR(input))
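For readers tracing the call path, the sketch below shows how the new `_xla_annotate_custom_sharding` binding is expected to be driven from Python; it mirrors the `annotate_custom_sharding` wrapper added later in this PR. The tensor shape, mesh, and partition spec are hypothetical placeholders, not part of this change.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # the C++ side requires the SPMD virtual device

t = torch.randn(8, 128).to(xm.xla_device())     # hypothetical input tensor
mesh = xs.get_1d_mesh("data")                   # 1-D mesh over all devices
op_sharding = mesh.get_op_sharding(("data", None))

# Equivalent to xs.annotate_custom_sharding(t, mesh, ("data", None)):
torch_xla._XLAC._xla_annotate_custom_sharding(t, op_sharding)
```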
38 changes: 29 additions & 9 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -766,23 +766,24 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
<< "Can't explicilty annotate with UNKNOWN sharding type.";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLATensor::ShardingSpecPtr new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(), static_cast<XlaDeviceType>(
xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
// For Non DeviceData IR values, we directly attach the sharding spec to the
// xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
XlaAnnotateCustomSharding(xtensor, sharding);
return;
}
}

XLATensor::ShardingSpecPtr new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(), static_cast<XlaDeviceType>(
xtensor->GetDevice().type())));

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
@@ -820,7 +821,9 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
XLA_CHECK_EQ(tensors.size(), 1);
cpu_tensor = tensors[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
@@ -833,6 +836,23 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void ShardingUtil::XlaAnnotateCustomSharding(const XLATensorPtr& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaAnnotateCustomSharding", 1);

XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
<< "Can't explicilty annotate with UNKNOWN sharding type.";
Comment on lines +843 to +846
Collaborator:

Don't we need more checks here? What happens if we call annotate_custom_sharding before mark_sharding? Is this supposed to work?

Collaborator (Author) @rpsilva-aws · May 28, 2025:

Yes, it should work, but the intended behavior here could be ambiguous. At that point, mark_sharding would similarly just add a custom sharding op, since the provided tensor has an IR custom sharding value/node rather than a device data node. If we want a subsequent mark_sharding to identify the parent DeviceData node IR and invoke a sharded data transfer, that is something we can discuss/evaluate.

The problem is that, in most cases, mark_sharding is intended to be used as expected on any DeviceData node, since it performs an async runtime buffer allocation over the sharded segments only. In that case, each device holds its own shard of the data, and that is reflected in the HLO inputs. But users may want to explicitly attach an XLA IR custom sharding to a device data node, for example to feed a replicated tensor as an input to the graph, or simply to add annotations that rely on the XLA compiler to insert the optimal CC ops to accommodate the hint.
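To make the ordering discussed in this thread concrete, here is a hedged sketch (not part of the change) of calling `annotate_custom_sharding` before `mark_sharding`. Per the reply above and the early-return path in `XlaMarkSharding` shown earlier in this diff, the second call is expected to add another custom sharding annotation rather than perform a sharded data transfer; the shapes and mesh are illustrative only.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

t = torch.randn(2, 4, 64, 64).to(xm.xla_device())
mesh = xs.get_1d_mesh("data")  # illustrative 1-D mesh over all devices

# Annotate first: only a custom sharding IR node is attached; no sharding
# spec is tied to the DeviceData node and no data is transferred.
xs.annotate_custom_sharding(t, mesh, (None, None, "data", None))

# mark_sharding now sees a non-DeviceData IR value (the custom sharding
# node), so it is expected to take the early-return path above and add
# another annotation instead of resharding the device data.
xs.mark_sharding(t, mesh, (None, None, None, "data"))
```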


XLATensor::ShardingSpecPtr sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
input->shape(),
static_cast<XlaDeviceType>(input->GetDevice().type())));
tensor_methods::custom_sharding_(input, sharding_spec);
}

void ShardingUtil::SetAutoSharding() {
// This stays on throughout the program.
use_auto_sharding = true;
6 changes: 6 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
@@ -123,6 +123,12 @@ class ShardingUtil {
static void XlaMarkSharding(const at::Tensor& input,
xla::OpSharding sharding);

// Add a custom sharding node IR to an XLATensor. Note that unlike
// XlaMarkSharding, this will not explicitly set a sharding spec tied to the
// DeviceData node, nor transfer any sharded data to the device. This serves
// merely as an XLA custom sharding annotation IR.
static void XlaAnnotateCustomSharding(const XLATensorPtr& input,
xla::OpSharding sharding);
//////////////////////////// Auto-Sharding ////////////////////////////

// Construct a device mesh for auto-sharding pass. Returns a tuple of mesh
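To illustrate the distinction drawn in the comment above, here is a short sketch of what the new test in this PR asserts at the HLO level: `mark_sharding` attaches the sharding to the tensor's HLO parameter (the DeviceData), while `annotate_custom_sharding` surfaces as a `custom-call` with `custom_call_target="Sharding"`. The shapes and mesh are illustrative, not taken from the test.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

t = torch.randn(2, 4, 64, 64).to(xm.xla_device())
mesh = xs.get_1d_mesh("data")

# Shards the backing device data; the sharding lands on the HLO parameter.
xs.mark_sharding(t, mesh, (None, None, "data", None))

# Adds only an IR annotation on top of that parameter; it appears as a
# custom-call with custom_call_target="Sharding" in the HLO.
xs.annotate_custom_sharding(t, mesh, (None, None, None, "data"))

hlo = torch_xla._XLAC._get_xla_tensors_hlo([t])
print('custom_call_target="Sharding"' in hlo)  # expected True, as in the test
```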
4 changes: 3 additions & 1 deletion torch_xla/distributed/spmd/__init__.py
@@ -4,7 +4,8 @@
mark_sharding, mark_sharding_with_gradients, clear_sharding, get_1d_mesh,
wrap_if_sharded, xla_patched_nn_linear_forward, set_global_mesh,
get_global_mesh, _mark_manual_sharding, enable_manual_sharding,
disable_manual_sharding, apply_backward_optimization_barrier, shard_as)
disable_manual_sharding, apply_backward_optimization_barrier, shard_as,
annotate_custom_sharding)
from .api import xla_distribute_tensor, xla_distribute_module, auto_policy

__all__ = [
@@ -20,6 +21,7 @@
"mark_sharding",
"mark_sharding_with_gradients",
"shard_as",
"annotate_custom_sharding",
"clear_sharding",
"get_1d_mesh",
"wrap_if_sharded",
34 changes: 34 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -563,6 +563,40 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
return wrap_as_sharded_tensor(t)


def annotate_custom_sharding(t: Union[torch.Tensor,
XLAShardedTensor], mesh: Mesh,
partition_spec: PartitionSpec) -> XLAShardedTensor:
"""
Annotates an existing tensor with a custom sharding IR node without modifying its data layout.

Unlike `mark_sharding`, this function only adds a custom sharding annotation to the XLA IR
without explicitly setting a sharding spec tied to the DeviceData node or transferring any
sharded data to the device. This allows providing explicit XLA sharding annotations of tensors
that have already been sharded with `mark_sharding`.

Args:
t: The input tensor to be annotated with custom sharding.
mesh: The device mesh that specifies the logical device topology.
partition_spec: The partitioning specification for each dimension of the input tensor.

Returns:
XLAShardedTensor: The input tensor wrapped as a sharded tensor with the custom sharding annotation.

Example:
>>> # First shard the tensor with mark_sharding
>>> sharded_tensor = xs.mark_sharding(tensor, mesh1, (0, 1, 2, 3))
>>> # Later, annotate with a different sharding for the XLA SPMD partitioner
>>> custom_sharded = xs.annotate_custom_sharding(sharded_tensor, mesh2, (0, 1, 2, 3))
"""
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

op_sharding = mesh.get_op_sharding(partition_spec)
annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding
annotate_func(unwrap_sharded_tensor(t), op_sharding)
return wrap_as_sharded_tensor(t)


def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: PartitionSpec) -> XLAShardedTensor:
"""