Skip to content

Commit 32344f3

Browse files
committed
Part 1: Disambiguate custom sharding op for DeviceData IR nodes
1 parent 1f9dd8f commit 32344f3

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -772,15 +772,11 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
772772
xtensor->shape(), static_cast<XlaDeviceType>(
773773
xtensor->GetDevice().type())));
774774

775-
// For Non DeviceData IR values, we directly attach the sharding spec
776-
// to the xtensor.
777-
const DeviceData* device_data_node = nullptr;
775+
// For IR values, we directly attach the sharding spec to the xtensor. In case
776+
// it is a DeviceData IR node, we similarly add a custom sharding IR node.
778777
if (xtensor->CurrentIrValue()) {
779-
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
780-
if (!device_data_node) {
781-
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
782-
return;
783-
}
778+
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
779+
return;
784780
}
785781

786782
// For data, we need to deal with the data transfers between
@@ -816,11 +812,12 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
816812
// tensor from the physical device to CPU. In that case, the value
817813
// must be present on the backend device.
818814
XLA_CHECK((xtensor->CurrentDataHandle() &&
819-
xtensor->CurrentDataHandle()->HasValue()) ||
820-
device_data_node != nullptr)
815+
xtensor->CurrentDataHandle()->HasValue()))
821816
<< "Cannot shard tensor. Data does not present on any device.";
822817
std::vector<XLATensorPtr> xla_tensors{xtensor};
823-
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
818+
auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
819+
XLA_CHECK_EQ(tensors.size(), 1);
820+
cpu_tensor = tensors[0];
824821
}
825822
auto xla_data = CreateTensorsData(
826823
std::vector<at::Tensor>{cpu_tensor},

0 commit comments

Comments
 (0)