Skip to content

Commit 4a9d16f

Browse files
author
jagadeesh
committed
fix pre-processing
Signed-off-by: jagadeesh <[email protected]>
1 parent a781796 commit 4a9d16f

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed

cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc

+25-26
Original file line numberDiff line numberDiff line change
@@ -71,45 +71,44 @@ std::vector<torch::jit::IValue> ResnetHandler::Preprocess(
7171

7272
cv::Mat image = cv::imdecode(data_it->second, cv::IMREAD_COLOR);
7373

74-
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
75-
7674
// Check if the image was successfully decoded
7775
if (image.empty()) {
7876
std::cerr << "Failed to decode the image." << std::endl;
7977
}
8078

81-
// Resize
82-
const int newWidth = 256, newHeight = 256;
83-
cv::Mat resizedImage;
84-
cv::resize(image, resizedImage, cv::Size(newWidth, newHeight));
85-
8679
// Crop image
87-
const int cropSize = 224;
88-
const int offsetW = (resizedImage.cols - cropSize) / 2;
89-
const int offsetH = (resizedImage.rows - cropSize) / 2;
80+
const int rows = img.rows;
81+
const int cols = img.cols;
82+
83+
const int cropSize = std::min(rows, cols);
84+
const int offsetW = (cols - cropSize) / 2;
85+
const int offsetH = (rows - cropSize) / 2;
9086

9187
const cv::Rect roi(offsetW, offsetH, cropSize, cropSize);
92-
cv::Mat croppedImage = resizedImage(roi).clone();
88+
image = image(roi);
9389

94-
// Convert the OpenCV image to a torch tensor
95-
// Drift in cropped image
96-
// Vision Crop: 114, 118, 115, 102, 106, 97
97-
// OpenCV Crop: 113, 118, 114, 100, 106, 97
98-
torch::TensorOptions options(torch::kByte);
99-
torch::Tensor tensorImage = torch::from_blob(
100-
croppedImage.data,
101-
{croppedImage.rows, croppedImage.cols, croppedImage.channels()},
102-
options);
90+
// Resize
91+
cv::resize(image, image, cv::Size(224, 224));
10392

93+
// Convert BGR to RGB format
94+
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
95+
96+
image.convertTo(img, CV_32FC3, 1 / 255.0);
97+
98+
// Convert the OpenCV image to a torch tensor
99+
torch::Tensor tensorImage = torch::from_blob(image.data, {image.rows, image.cols, 3}, c10::kFloat);
104100
tensorImage = tensorImage.permute({2, 0, 1});
105-
tensorImage = tensorImage.to(torch::kFloat32) / 255.0;
101+
tensorImage.unsqueeze_(0);
106102

107103
// Normalize
108-
torch::Tensor normalizedTensorImage =
109-
torch::data::transforms::Normalize<>(
110-
{0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(tensorImage);
111-
normalizedTensorImage.clone();
112-
batch_tensors.emplace_back(normalizedTensorImage.to(*device));
104+
std::vector<double> norm_mean = {0.485, 0.456, 0.406};
105+
std::vector<double> norm_std = {0.229, 0.224, 0.225};
106+
107+
tensorImage =
108+
torch::data::transforms::Normalize<>(norm_mean, norm_std)(tensorImage);
109+
110+
tensorImage.clone();
111+
batch_tensors.emplace_back(tensorImage.to(*device));
113112
idx_to_req_id.second[idx++] = request.request_id;
114113
} else if (dtype_it->second == "List") {
115114
// case3: the image is a list

0 commit comments

Comments
 (0)