Skip to content

Commit e0715b9

Browse files
committed
Support `pytorch_half_pixel` value for `coordinate_transformation_mode`
Support the `pytorch_half_pixel` coordinate transformation mode in the `Resize` operator. See https://onnx.ai/onnx/operators/onnx__Resize.html#attributes. There are some open upstream spec issues around this value; they are linked in the code comments in `src/ops/resize.rs`. The implementation matches the current ONNX spec and ONNX Runtime (ORT).
1 parent 4329a65 commit e0715b9

File tree

6 files changed

+46
-13
lines changed

6 files changed

+46
-13
lines changed

rten-convert/rten_convert/schema_generated.py

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ class CoordTransformMode(object):
143143
HalfPixel = 0
144144
Asymmetric = 1
145145
AlignCorners = 2
146+
PytorchHalfPixel = 3
146147

147148

148149
class NearestMode(object):

src/model_builder.rs

+3
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,9 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
796796
CoordTransformMode::Asymmetric => sg::CoordTransformMode::Asymmetric,
797797
CoordTransformMode::HalfPixel => sg::CoordTransformMode::HalfPixel,
798798
CoordTransformMode::AlignCorners => sg::CoordTransformMode::AlignCorners,
799+
CoordTransformMode::PytorchHalfPixel => {
800+
sg::CoordTransformMode::PytorchHalfPixel
801+
}
799802
};
800803
let nearest_mode = match args.nearest_mode {
801804
NearestMode::Ceil => sg::NearestMode::Ceil,

src/op_registry.rs

+1
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,7 @@ impl_read_op!(Resize, attrs_as_resize_attrs, |attrs: sg::ResizeAttrs| {
817817
sg::CoordTransformMode::Asymmetric => CoordTransformMode::Asymmetric,
818818
sg::CoordTransformMode::HalfPixel => CoordTransformMode::HalfPixel,
819819
sg::CoordTransformMode::AlignCorners => CoordTransformMode::AlignCorners,
820+
sg::CoordTransformMode::PytorchHalfPixel => CoordTransformMode::PytorchHalfPixel,
820821
_ => CoordTransformMode::default(),
821822
};
822823

src/ops/resize.rs

+27-7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ fn input_coord(
5151
Ctm::AlignCorners => {
5252
dest_coord as f32 * (length_original - 1) as f32 / (length_resized - 1) as f32
5353
}
54+
Ctm::PytorchHalfPixel => {
55+
// There are some open questions about this transform mode, see
56+
// https://github.com/onnx/onnx/issues/4275 (applies to cubic interpolation only)
57+
// and https://github.com/onnx/onnx/issues/4276 (comparison with
58+
// PyTorch behavior). This implementation does however match
59+
// ONNX Runtime (https://github.com/microsoft/onnxruntime/blob/24620e70d9f14956a0dc84bb8a332dcd64c95a94/onnxruntime/core/providers/cpu/tensor/upsamplebase.h#L331)
60+
if length_resized > 1 {
61+
(dest_coord as f32 + 0.5) / scale - 0.5
62+
} else {
63+
0.
64+
}
65+
}
5466
}
5567
}
5668

@@ -73,6 +85,7 @@ pub enum CoordTransformMode {
7385
HalfPixel,
7486
Asymmetric,
7587
AlignCorners,
88+
PytorchHalfPixel,
7689
}
7790

7891
const CHAN_GROUP_SIZE: usize = 4;
@@ -648,18 +661,25 @@ mod tests {
648661
coord_transform_mode: None,
649662
expected: Tensor::from_data(&[1, 1, 0, 0], vec![]),
650663
},
651-
// Scale width and height by 0.5x
664+
// Scale to output width and height less than 2, using `HalfPixel`
665+
// `coord_transform_mode`.
666+
//
667+
// When the output size is < 2, `half_pixel` and `pytorch_half_pixel`
668+
// produce different results. Otherwise they are the same.
652669
Case {
653670
image,
654671
scales: vec![1., 1., 0.5, 0.5],
655-
coord_transform_mode: None,
656-
657-
// OpenCV and PyTorch produce different results for this case.
658-
// This result matches OpenCV. This relates to the `half_pixel`
659-
// vs `pytorch_half_pixel` values for the `coordinate_transformation_mode`
660-
// attribute in the ONNX op.
672+
coord_transform_mode: Some(CoordTransformMode::HalfPixel),
661673
expected: Tensor::from_data(&[1, 1, 1, 1], vec![0.5]),
662674
},
675+
// Scale to output width and height less than 2, using `PytorchHalfPixel`
676+
// `coord_transform_mode`.
677+
Case {
678+
image,
679+
scales: vec![1., 1., 0.5, 0.5],
680+
coord_transform_mode: Some(CoordTransformMode::PytorchHalfPixel),
681+
expected: Tensor::from_data(&[1, 1, 1, 1], vec![0.2]),
682+
},
663683
// Scale width and height by 1x
664684
Case {
665685
image,

src/schema.fbs

+2-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ enum DataType: ubyte {
155155
enum CoordTransformMode: ubyte {
156156
HalfPixel,
157157
Asymmetric,
158-
AlignCorners
158+
AlignCorners,
159+
PytorchHalfPixel,
159160
}
160161

161162
// Rounding modes supported by Resize operator when `ResizeMode` is `Nearest`.

src/schema_generated.rs

+12-5
Original file line numberDiff line numberDiff line change
@@ -835,16 +835,17 @@ pub const ENUM_MIN_COORD_TRANSFORM_MODE: u8 = 0;
835835
since = "2.0.0",
836836
note = "Use associated constants instead. This will no longer be generated in 2021."
837837
)]
838-
pub const ENUM_MAX_COORD_TRANSFORM_MODE: u8 = 2;
838+
pub const ENUM_MAX_COORD_TRANSFORM_MODE: u8 = 3;
839839
#[deprecated(
840840
since = "2.0.0",
841841
note = "Use associated constants instead. This will no longer be generated in 2021."
842842
)]
843843
#[allow(non_camel_case_types)]
844-
pub const ENUM_VALUES_COORD_TRANSFORM_MODE: [CoordTransformMode; 3] = [
844+
pub const ENUM_VALUES_COORD_TRANSFORM_MODE: [CoordTransformMode; 4] = [
845845
CoordTransformMode::HalfPixel,
846846
CoordTransformMode::Asymmetric,
847847
CoordTransformMode::AlignCorners,
848+
CoordTransformMode::PytorchHalfPixel,
848849
];
849850

850851
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
@@ -855,17 +856,23 @@ impl CoordTransformMode {
855856
pub const HalfPixel: Self = Self(0);
856857
pub const Asymmetric: Self = Self(1);
857858
pub const AlignCorners: Self = Self(2);
859+
pub const PytorchHalfPixel: Self = Self(3);
858860

859861
pub const ENUM_MIN: u8 = 0;
860-
pub const ENUM_MAX: u8 = 2;
861-
pub const ENUM_VALUES: &'static [Self] =
862-
&[Self::HalfPixel, Self::Asymmetric, Self::AlignCorners];
862+
pub const ENUM_MAX: u8 = 3;
863+
pub const ENUM_VALUES: &'static [Self] = &[
864+
Self::HalfPixel,
865+
Self::Asymmetric,
866+
Self::AlignCorners,
867+
Self::PytorchHalfPixel,
868+
];
863869
/// Returns the variant's name, or `None` if unknown.
864870
pub fn variant_name(self) -> Option<&'static str> {
865871
match self {
866872
Self::HalfPixel => Some("HalfPixel"),
867873
Self::Asymmetric => Some("Asymmetric"),
868874
Self::AlignCorners => Some("AlignCorners"),
875+
Self::PytorchHalfPixel => Some("PytorchHalfPixel"),
869876
_ => None,
870877
}
871878
}

0 commit comments

Comments
 (0)