Question not an issue #73
Description
Hello, I followed https://github.com/cedrickchee/pytorch-android to build an Android app that runs a ResNet18 for binary classification. I changed the app to take a still image when a button is clicked. Currently my model doesn't work well in the app. I assume that is because I'm not applying the same preprocessing as in PyTorch, and I don't know how to do that in the Android app. In PyTorch I had:

transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),

I assume I should do the same thing on Android for the model to work well. I also don't understand what b_mean, g_mean, and r_mean represent in the cpp file in the original code below. Should I replace them with the means that I need, [0.485, 0.456, 0.406], or are they only there to get the RGB format? Thank you so much for your help.

```cpp
float b_mean = 104.00698793f;
float g_mean = 116.66876762f;
float r_mean = 122.67891434f;

// Planar (CHW) indices: one full plane per channel.
auto b_i = 0 * IMG_H * IMG_W + j * IMG_W + i;
auto g_i = 1 * IMG_H * IMG_W + j * IMG_W + i;
auto r_i = 2 * IMG_H * IMG_W + j * IMG_W + i;
if (infer_HWC) {
    // Interleaved (HWC) indices: the three channels of a pixel are adjacent.
    b_i = (j * IMG_W + i) * IMG_C;
    g_i = (j * IMG_W + i) * IMG_C + 1;
    r_i = (j * IMG_W + i) * IMG_C + 2;
}

// YUV -> RGB conversion, clamped to [0, 255], minus the per-channel mean.
input_data[r_i] = -r_mean + (float) ((float) min(255., max(0., (float) (y + 1.402 * (v - 128)))));
input_data[g_i] = -g_mean + (float) ((float) min(255., max(0., (float) (y - 0.34414 * (u - 128) - 0.71414 * (v - 128)))));
input_data[b_i] = -b_mean + (float) ((float) min(255., max(0., (float) (y + 1.772 * (u - v)))));
```
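For reference, here is a minimal sketch of what `transforms.ToTensor()` followed by `transforms.Normalize(mean, std)` computes per pixel: `(value / 255 - mean) / std`. It assumes the image has already been resized and center-cropped, and that it is available as an interleaved 8-bit RGB buffer. The helper name `normalize_rgb_to_chw` and the buffer layout are hypothetical and not taken from the linked project. Note that 0.485 * 255 = 123.675, 0.456 * 255 = 116.28, and 0.406 * 255 = 103.53, which suggests the existing `r_mean`, `g_mean`, and `b_mean` are per-channel means on the 0-255 scale, applied without the division by the standard deviation. The snippet above also writes blue into plane 0, whereas torchvision-style preprocessing puts red first.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Statistics from the transforms.Normalize call in the question, in RGB order.
constexpr float kMean[3] = {0.485f, 0.456f, 0.406f};
constexpr float kStd[3]  = {0.229f, 0.224f, 0.225f};

// Hypothetical helper (not part of the linked project): converts an
// interleaved 8-bit RGB image (HWC) into a planar float buffer (CHW),
// applying the ToTensor scaling (divide by 255) and then Normalize.
std::vector<float> normalize_rgb_to_chw(const std::vector<std::uint8_t>& rgb_hwc,
                                        std::size_t height, std::size_t width) {
    assert(rgb_hwc.size() == height * width * 3);
    std::vector<float> chw(3 * height * width);
    for (std::size_t y = 0; y < height; ++y) {
        for (std::size_t x = 0; x < width; ++x) {
            for (std::size_t c = 0; c < 3; ++c) {
                // Scale to [0, 1] as ToTensor does for 8-bit input.
                const float scaled = rgb_hwc[(y * width + x) * 3 + c] / 255.0f;
                // Plane c holds channel c, so plane 0 is red.
                chw[c * height * width + y * width + x] =
                    (scaled - kMean[c]) / kStd[c];
            }
        }
    }
    return chw;
}
```

In the existing loop, the same arithmetic could be applied to each clamped RGB value before it is stored in `input_data`. Whether that alone fixes the predictions also depends on the resize and crop steps, which this sketch does not cover.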