@@ -495,7 +495,7 @@ std::tuple<int64_t, int64_t> ComputeRepeatAndRepeatStride(
495495 const std::vector<int64_t >& device_elements) {
496496 int64_t first_device_id = device_elements.at (0 );
497497 int64_t first_device_id_count = 0 ;
498- for (size_t i = 0 ; i < device_elements.size (); ++i) {
498+ for (size_t i = 0 ; i < static_cast < size_t >( device_elements.size () ); ++i) {
499499 if (device_elements.at (i) == first_device_id) {
500500 ++first_device_id_count;
501501 }
@@ -505,8 +505,8 @@ std::tuple<int64_t, int64_t> ComputeRepeatAndRepeatStride(
505505 // Check if the device mesh pattern is supported.
506506 // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1].
507507 // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0].
508- for (size_t repeat = 0 ; repeat < first_device_id_count; ++repeat) {
509- for (size_t device_id = 0 ; device_id < repeat_stride; ++device_id) {
508+ for (size_t repeat = 0 ; repeat < static_cast < size_t >( first_device_id_count) ; ++repeat) {
509+ for (size_t device_id = 0 ; device_id < static_cast < size_t >( repeat_stride) ; ++device_id) {
510510 ORT_ENFORCE (
511511 device_elements.at (repeat * repeat_stride + device_id) == device_elements.at (device_id),
512512 " Unsupported device mesh pattern." );
@@ -556,7 +556,7 @@ std::tuple<bool, TensorPartitionSpec> ComputeNativeSpecForTwoAxisDecomposition(
556556 // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1]
557557 std::vector<AxisPartitionSpec> dst_axis_specs;
558558 for (size_t src_axis = 0 ; src_axis < src_shape.size (); ++src_axis) {
559- if (src_axis != decomposed_axis_in_src) {
559+ if (src_axis != static_cast < size_t >( decomposed_axis_in_src) ) {
560560 // Sharding spec is copied if the axis is not decomposed.
561561 // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2]
562562 // The spec for "5" is copied.
@@ -606,7 +606,7 @@ std::tuple<bool, TensorPartitionSpec> ComputeNativeSpecForTwoAxisDecomposition(
606606 DeviceMesh dst_device_mesh;
607607 std::tie (repeats, repeat_stride) = ComputeRepeatAndRepeatStride (src_spec.device_mesh .device_mesh_elements );
608608 for (size_t src_axis = 0 ; src_axis < src_shape.size (); ++src_axis) {
609- if (src_axis != decomposed_axis_in_src) {
609+ if (src_axis != static_cast < size_t >( decomposed_axis_in_src) ) {
610610 dst_axis_specs.push_back (AxisPartitionSpec::CreateCopy (src_spec.GetAxisSpec (src_axis)));
611611 } else if (dst_shape[decomposition_axis_in_dst] == 1 ) {
612612 // S[0] -> RS[0]
@@ -660,7 +660,7 @@ std::tuple<bool, TensorPartitionSpec> ComputeNativeSpecForTwoAxisDecomposition(
660660 // Source tensor is sharded on non-decomposed axis.
661661 std::vector<AxisPartitionSpec> dst_axis_specs;
662662 for (size_t src_axis = 0 ; src_axis < src_shape.size (); ++src_axis) {
663- if (src_axis != decomposed_axis_in_src) {
663+ if (src_axis != static_cast < size_t >( decomposed_axis_in_src) ) {
664664 dst_axis_specs.push_back (AxisPartitionSpec::CreateCopy (src_spec.GetAxisSpec (src_axis)));
665665 } else {
666666 // R -> RR
0 commit comments