Skip to content

Commit e886c55

Browse files
committed
Improve naming of buffers
1 parent 26e1d67 commit e886c55

File tree

1 file changed

+45
-27
lines changed

1 file changed

+45
-27
lines changed

lib/libbackscrub.cc

+45-27
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,38 @@ struct normalization_t {
2727
struct backscrub_ctx_t {
2828
// Loaded inference model
2929
std::unique_ptr<tflite::FlatBufferModel> model;
30+
3031
// Model interpreter instance
3132
std::unique_ptr<tflite::Interpreter> interpreter;
33+
3234
// Specific model type & input normalization
3335
modeltype_t modeltype;
3436
normalization_t norm;
37+
3538
// Optional callbacks with caller-provided context
3639
void (*ondebug)(void *ctx, const char *msg);
3740
void (*onprep)(void *ctx);
3841
void (*oninfer)(void *ctx);
3942
void (*onmask)(void *ctx);
4043
void *caller_ctx;
41-
// Processing state
42-
cv::Mat input;
43-
cv::Mat output;
44-
cv::Rect roidim;
45-
cv::Mat mask;
46-
cv::Mat mroi;
47-
cv::Mat ofinal;
48-
cv::Size blur;
44+
45+
// Single step variables
46+
cv::Mat input; // NN input tensors
47+
cv::Mat output; // NN output tensors
48+
cv::Mat ofinal; // NN output (post-processed mask)
49+
50+
float src_ratio; // Source image aspect ratio
51+
cv::Rect src_roidim; // Source image rect of interest
52+
cv::Mat mask_region; // Region of the final mask to operate on
53+
54+
float net_ratio; // NN input image aspect ratio
55+
cv::Rect net_roidim; // NN input image rect of interest
56+
57+
// Result stitching variables
4958
cv::Mat in_u8_bgr;
50-
cv::Rect in_roidim;
51-
float ratio;
52-
float frameratio;
59+
60+
cv::Size blur; // Size of blur on final mask
61+
cv::Mat mask; // Fully processed mask (full image)
5362
};
5463

5564
// Debug helper
@@ -190,14 +199,17 @@ void *bs_maskgen_new(
190199
) {
191200
// Allocate context
192201
backscrub_ctx_t *pctx = new backscrub_ctx_t;
202+
193203
// Take a reference so we can write tidy code with ctx.<x>
194204
backscrub_ctx_t &ctx = *pctx;
205+
195206
// Save callbacks
196207
ctx.ondebug = ondebug;
197208
ctx.onprep = onprep;
198209
ctx.oninfer = oninfer;
199210
ctx.onmask = onmask;
200211
ctx.caller_ctx = caller_ctx;
212+
201213
// Load model
202214
ctx.model = tflite::FlatBufferModel::BuildFromFile(modelname.c_str());
203215

@@ -209,18 +221,23 @@ void *bs_maskgen_new(
209221

210222
// Determine model type and normalization values
211223
ctx.modeltype = get_modeltype(modelname);
212-
ctx.norm = get_normalization(ctx.modeltype);
213224

214225
if (modeltype_t::Unknown == ctx.modeltype) {
215226
_dbg(ctx, "error: unknown model type '%s'.\n", modelname.c_str());
216227
bs_maskgen_delete(pctx);
217228
return nullptr;
218229
}
219230

231+
ctx.norm = get_normalization(ctx.modeltype);
232+
220233
// Build the interpreter
221234
tflite::ops::builtin::BuiltinOpResolver resolver;
235+
222236
// custom op for Google Meet network
223-
resolver.AddCustom("Convolution2DTransposeBias", mediapipe::tflite_operations::RegisterConvolution2DTransposeBias());
237+
resolver.AddCustom(
238+
"Convolution2DTransposeBias",
239+
mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()
240+
);
224241
tflite::InterpreterBuilder builder(*ctx.model, resolver);
225242
builder(&ctx.interpreter);
226243

@@ -250,22 +267,22 @@ void *bs_maskgen_new(
250267
return nullptr;
251268
}
252269

253-
ctx.ratio = (float)ctx.input.rows / (float)ctx.input.cols;
254-
ctx.frameratio = (float)height / (float)width;
270+
ctx.net_ratio = (float)ctx.input.rows / (float)ctx.input.cols;
271+
ctx.src_ratio = (float)height / (float)width;
255272

256273
// initialize mask and model-aspect ROI in center
257-
if (ctx.frameratio < ctx.ratio) {
274+
if (ctx.src_ratio < ctx.net_ratio) {
258275
// if frame is wider than model, then use only the frame center
259-
ctx.roidim = cv::Rect((width - height / ctx.ratio) / 2, 0, height / ctx.ratio, height);
260-
ctx.in_roidim = cv::Rect(0, 0, ctx.input.cols, ctx.input.rows);
276+
ctx.src_roidim = cv::Rect((width - height / ctx.net_ratio) / 2, 0, height / ctx.net_ratio, height);
277+
ctx.net_roidim = cv::Rect(0, 0, ctx.input.cols, ctx.input.rows);
261278
} else {
262279
// if model is wider than the frame, center the frame in the model
263-
ctx.roidim = cv::Rect(0, 0, width, height);
264-
ctx.in_roidim = cv::Rect((ctx.input.cols - ctx.input.rows / ctx.frameratio) / 2, 0, ctx.input.rows / ctx.frameratio, ctx.input.rows);
280+
ctx.src_roidim = cv::Rect(0, 0, width, height);
281+
ctx.net_roidim = cv::Rect((ctx.input.cols - ctx.input.rows / ctx.src_ratio) / 2, 0, ctx.input.rows / ctx.src_ratio, ctx.input.rows);
265282
}
266283

267284
ctx.mask = cv::Mat::ones(height, width, CV_8UC1) * 255;
268-
ctx.mroi = ctx.mask(ctx.roidim);
285+
ctx.mask_region = ctx.mask(ctx.src_roidim);
269286

270287
ctx.in_u8_bgr = cv::Mat(ctx.input.rows, ctx.input.cols, CV_8UC3, cv::Scalar(0, 0, 0));
271288

@@ -301,11 +318,12 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
301318
backscrub_ctx_t &ctx = *((backscrub_ctx_t *)context);
302319

303320
// map ROI
304-
cv::Mat roi = frame(ctx.roidim);
321+
cv::Mat roi = frame(ctx.src_roidim);
322+
323+
cv::Mat in_roi = ctx.in_u8_bgr(ctx.net_roidim);
324+
cv::resize(roi, in_roi, ctx.net_roidim.size());
305325

306326
cv::Mat in_u8_rgb;
307-
cv::Mat in_roi = ctx.in_u8_bgr(ctx.in_roidim);
308-
cv::resize(roi, in_roi, ctx.in_roidim.size());
309327
cv::cvtColor(ctx.in_u8_bgr, in_u8_rgb, cv::COLOR_BGR2RGB);
310328

311329
// TODO: can convert directly to float?
@@ -378,7 +396,7 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
378396
* probability in [0.0, 1.0].
379397
*/
380398
for (unsigned int n = 0; n < ctx.output.total(); n++) {
381-
float exp0 = expf(tmp[2 * n ]);
399+
float exp0 = expf(tmp[2 * n ]);
382400
float exp1 = expf(tmp[2 * n + 1]);
383401
float p0 = exp0 / (exp0 + exp1);
384402
float p1 = exp1 / (exp0 + exp1);
@@ -398,10 +416,10 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
398416

399417
// scale up into full-sized mask
400418
cv::Mat tmpbuf;
401-
cv::resize(ctx.ofinal(ctx.in_roidim), tmpbuf, ctx.mroi.size());
419+
cv::resize(ctx.ofinal(ctx.net_roidim), tmpbuf, ctx.mask_region.size());
402420

403421
// blur at full size for maximum smoothness
404-
cv::blur(tmpbuf, ctx.mroi, ctx.blur);
422+
cv::blur(tmpbuf, ctx.mask_region, ctx.blur);
405423

406424
// copy out
407425
mask = ctx.mask;

0 commit comments

Comments
 (0)