@@ -71,45 +71,44 @@ std::vector<torch::jit::IValue> ResnetHandler::Preprocess(
71
71
72
72
cv::Mat image = cv::imdecode (data_it->second , cv::IMREAD_COLOR);
73
73
74
- cv::cvtColor (image, image, cv::COLOR_BGR2RGB);
75
-
76
74
// Check if the image was successfully decoded
77
75
if (image.empty ()) {
78
76
std::cerr << " Failed to decode the image." << std::endl;
79
77
}
80
78
81
- // Resize
82
- const int newWidth = 256 , newHeight = 256 ;
83
- cv::Mat resizedImage;
84
- cv::resize (image, resizedImage, cv::Size (newWidth, newHeight));
85
-
86
79
// Crop image
87
- const int cropSize = 224 ;
88
- const int offsetW = (resizedImage.cols - cropSize) / 2 ;
89
- const int offsetH = (resizedImage.rows - cropSize) / 2 ;
80
+ const int rows = img.rows ;
81
+ const int cols = img.cols ;
82
+
83
+ const int cropSize = std::min (rows, cols);
84
+ const int offsetW = (cols - cropSize) / 2 ;
85
+ const int offsetH = (rows - cropSize) / 2 ;
90
86
91
87
const cv::Rect roi (offsetW, offsetH, cropSize, cropSize);
92
- cv::Mat croppedImage = resizedImage (roi). clone ( );
88
+ image = image (roi);
93
89
94
- // Convert the OpenCV image to a torch tensor
95
- // Drift in cropped image
96
- // Vision Crop: 114, 118, 115, 102, 106, 97
97
- // OpenCV Crop: 113, 118, 114, 100, 106, 97
98
- torch::TensorOptions options (torch::kByte );
99
- torch::Tensor tensorImage = torch::from_blob (
100
- croppedImage.data ,
101
- {croppedImage.rows , croppedImage.cols , croppedImage.channels ()},
102
- options);
90
+ // Resize
91
+ cv::resize (image, image, cv::Size (224 , 224 ));
103
92
93
+ // Convert BGR to RGB format
94
+ cv::cvtColor (image, image, cv::COLOR_BGR2RGB);
95
+
96
+ image.convertTo (img, CV_32FC3, 1 / 255.0 );
97
+
98
+ // Convert the OpenCV image to a torch tensor
99
+ torch::Tensor tensorImage = torch::from_blob (image.data , {image.rows , image.cols , 3 }, c10::kFloat );
104
100
tensorImage = tensorImage.permute ({2 , 0 , 1 });
105
- tensorImage = tensorImage. to (torch:: kFloat32 ) / 255.0 ;
101
+ tensorImage. unsqueeze_ ( 0 ) ;
106
102
107
103
// Normalize
108
- torch::Tensor normalizedTensorImage =
109
- torch::data::transforms::Normalize<>(
110
- {0.485 , 0.456 , 0.406 }, {0.229 , 0.224 , 0.225 })(tensorImage);
111
- normalizedTensorImage.clone ();
112
- batch_tensors.emplace_back (normalizedTensorImage.to (*device));
104
+ std::vector<double > norm_mean = {0.485 , 0.456 , 0.406 };
105
+ std::vector<double > norm_std = {0.229 , 0.224 , 0.225 };
106
+
107
+ tensorImage =
108
+ torch::data::transforms::Normalize<>(norm_mean, norm_std)(tensorImage);
109
+
110
+ tensorImage.clone ();
111
+ batch_tensors.emplace_back (tensorImage.to (*device));
113
112
idx_to_req_id.second [idx++] = request.request_id ;
114
113
} else if (dtype_it->second == " List" ) {
115
114
// case3: the image is a list
0 commit comments