@@ -772,15 +772,11 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
                           xtensor->shape(), static_cast<XlaDeviceType>(
                                                 xtensor->GetDevice().type())));
 
-  // For Non DeviceData IR values, we directly attach the sharding spec
-  // to the xtensor.
-  const DeviceData* device_data_node = nullptr;
+  // For IR values, we directly attach the sharding spec to the xtensor. In case
+  // it is a DeviceData IR node, we similarly add a custom sharding IR node.
   if (xtensor->CurrentIrValue()) {
-    device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
-    if (!device_data_node) {
-      tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
-      return;
-    }
+    tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
+    return;
   }
 
   // For data, we need to deal with the data transfers between
@@ -816,11 +812,12 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
     // tensor from the physical device to CPU. In that case, the value
     // must be present on the backend device.
     XLA_CHECK((xtensor->CurrentDataHandle() &&
-               xtensor->CurrentDataHandle()->HasValue()) ||
-              device_data_node != nullptr)
+               xtensor->CurrentDataHandle()->HasValue()))
         << "Cannot shard tensor. Data does not present on any device.";
     std::vector<XLATensorPtr> xla_tensors{xtensor};
-    cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
+    auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
+    XLA_CHECK_EQ(tensors.size(), 1);
+    cpu_tensor = tensors[0];
   }
   auto xla_data = CreateTensorsData(
       std::vector<at::Tensor>{cpu_tensor},