@@ -27,29 +27,38 @@ struct normalization_t {
27
27
struct backscrub_ctx_t {
28
28
// Loaded inference model
29
29
std::unique_ptr<tflite::FlatBufferModel> model;
30
+
30
31
// Model interpreter instance
31
32
std::unique_ptr<tflite::Interpreter> interpreter;
33
+
32
34
// Specific model type & input normalization
33
35
modeltype_t modeltype;
34
36
normalization_t norm;
37
+
35
38
// Optional callbacks with caller-provided context
36
39
void (*ondebug)(void *ctx, const char *msg);
37
40
void (*onprep)(void *ctx);
38
41
void (*oninfer)(void *ctx);
39
42
void (*onmask)(void *ctx);
40
43
void *caller_ctx;
41
- // Processing state
42
- cv::Mat input;
43
- cv::Mat output;
44
- cv::Rect roidim;
45
- cv::Mat mask;
46
- cv::Mat mroi;
47
- cv::Mat ofinal;
48
- cv::Size blur;
44
+
45
+ // Single step variables
46
+ cv::Mat input; // NN input tensors
47
+ cv::Mat output; // NN output tensors
48
+ cv::Mat ofinal; // NN output (post-processed mask)
49
+
50
+ float src_ratio; // Source image aspect ratio
51
+ cv::Rect src_roidim; // Source image rect of interest
52
+ cv::Mat mask_region; // Region of the final mask to operate on
53
+
54
+ float net_ratio; // NN input image aspect ratio
55
+ cv::Rect net_roidim; // NN input image rect of interest
56
+
57
+ // Result stitching variables
49
58
cv::Mat in_u8_bgr;
50
- cv:: Rect in_roidim;
51
- float ratio;
52
- float frameratio;
59
+
60
+ cv:: Size blur; // Size of blur on final mask
61
+ cv::Mat mask; // Fully processed mask (full image)
53
62
};
54
63
55
64
// Debug helper
@@ -190,14 +199,17 @@ void *bs_maskgen_new(
190
199
) {
191
200
// Allocate context
192
201
backscrub_ctx_t *pctx = new backscrub_ctx_t ;
202
+
193
203
// Take a reference so we can write tidy code with ctx.<x>
194
204
backscrub_ctx_t &ctx = *pctx;
205
+
195
206
// Save callbacks
196
207
ctx.ondebug = ondebug;
197
208
ctx.onprep = onprep;
198
209
ctx.oninfer = oninfer;
199
210
ctx.onmask = onmask;
200
211
ctx.caller_ctx = caller_ctx;
212
+
201
213
// Load model
202
214
ctx.model = tflite::FlatBufferModel::BuildFromFile (modelname.c_str ());
203
215
@@ -209,18 +221,23 @@ void *bs_maskgen_new(
209
221
210
222
// Determine model type and normalization values
211
223
ctx.modeltype = get_modeltype (modelname);
212
- ctx.norm = get_normalization (ctx.modeltype );
213
224
214
225
if (modeltype_t ::Unknown == ctx.modeltype ) {
215
226
_dbg (ctx, " error: unknown model type '%s'.\n " , modelname.c_str ());
216
227
bs_maskgen_delete (pctx);
217
228
return nullptr ;
218
229
}
219
230
231
+ ctx.norm = get_normalization (ctx.modeltype );
232
+
220
233
// Build the interpreter
221
234
tflite::ops::builtin::BuiltinOpResolver resolver;
235
+
222
236
// custom op for Google Meet network
223
- resolver.AddCustom (" Convolution2DTransposeBias" , mediapipe::tflite_operations::RegisterConvolution2DTransposeBias ());
237
+ resolver.AddCustom (
238
+ " Convolution2DTransposeBias" ,
239
+ mediapipe::tflite_operations::RegisterConvolution2DTransposeBias ()
240
+ );
224
241
tflite::InterpreterBuilder builder (*ctx.model , resolver);
225
242
builder (&ctx.interpreter );
226
243
@@ -250,22 +267,22 @@ void *bs_maskgen_new(
250
267
return nullptr ;
251
268
}
252
269
253
- ctx.ratio = (float )ctx.input .rows / (float )ctx.input .cols ;
254
- ctx.frameratio = (float )height / (float )width;
270
+ ctx.net_ratio = (float )ctx.input .rows / (float )ctx.input .cols ;
271
+ ctx.src_ratio = (float )height / (float )width;
255
272
256
273
// initialize mask and model-aspect ROI in center
257
- if (ctx.frameratio < ctx.ratio ) {
274
+ if (ctx.src_ratio < ctx.net_ratio ) {
258
275
// if frame is wider than model, then use only the frame center
259
- ctx.roidim = cv::Rect ((width - height / ctx.ratio ) / 2 , 0 , height / ctx.ratio , height);
260
- ctx.in_roidim = cv::Rect (0 , 0 , ctx.input .cols , ctx.input .rows );
276
+ ctx.src_roidim = cv::Rect ((width - height / ctx.net_ratio ) / 2 , 0 , height / ctx.net_ratio , height);
277
+ ctx.net_roidim = cv::Rect (0 , 0 , ctx.input .cols , ctx.input .rows );
261
278
} else {
262
279
// if model is wider than the frame, center the frame in the model
263
- ctx.roidim = cv::Rect (0 , 0 , width, height);
264
- ctx.in_roidim = cv::Rect ((ctx.input .cols - ctx.input .rows / ctx.frameratio ) / 2 , 0 , ctx.input .rows / ctx.frameratio , ctx.input .rows );
280
+ ctx.src_roidim = cv::Rect (0 , 0 , width, height);
281
+ ctx.net_roidim = cv::Rect ((ctx.input .cols - ctx.input .rows / ctx.src_ratio ) / 2 , 0 , ctx.input .rows / ctx.src_ratio , ctx.input .rows );
265
282
}
266
283
267
284
ctx.mask = cv::Mat::ones (height, width, CV_8UC1) * 255 ;
268
- ctx.mroi = ctx.mask (ctx.roidim );
285
+ ctx.mask_region = ctx.mask (ctx.src_roidim );
269
286
270
287
ctx.in_u8_bgr = cv::Mat (ctx.input .rows , ctx.input .cols , CV_8UC3, cv::Scalar (0 , 0 , 0 ));
271
288
@@ -301,11 +318,12 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
301
318
backscrub_ctx_t &ctx = *((backscrub_ctx_t *)context);
302
319
303
320
// map ROI
304
- cv::Mat roi = frame (ctx.roidim );
321
+ cv::Mat roi = frame (ctx.src_roidim );
322
+
323
+ cv::Mat in_roi = ctx.in_u8_bgr (ctx.net_roidim );
324
+ cv::resize (roi, in_roi, ctx.net_roidim .size ());
305
325
306
326
cv::Mat in_u8_rgb;
307
- cv::Mat in_roi = ctx.in_u8_bgr (ctx.in_roidim );
308
- cv::resize (roi, in_roi, ctx.in_roidim .size ());
309
327
cv::cvtColor (ctx.in_u8_bgr , in_u8_rgb, cv::COLOR_BGR2RGB);
310
328
311
329
// TODO: can convert directly to float?
@@ -378,7 +396,7 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
378
396
* probability in [0.0, 1.0].
379
397
*/
380
398
for (unsigned int n = 0 ; n < ctx.output .total (); n++) {
381
- float exp0 = expf (tmp[2 * n ]);
399
+ float exp0 = expf (tmp[2 * n ]);
382
400
float exp1 = expf (tmp[2 * n + 1 ]);
383
401
float p0 = exp0 / (exp0 + exp1);
384
402
float p1 = exp1 / (exp0 + exp1);
@@ -398,10 +416,10 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
398
416
399
417
// scale up into full-sized mask
400
418
cv::Mat tmpbuf;
401
- cv::resize (ctx.ofinal (ctx.in_roidim ), tmpbuf, ctx.mroi .size ());
419
+ cv::resize (ctx.ofinal (ctx.net_roidim ), tmpbuf, ctx.mask_region .size ());
402
420
403
421
// blur at full size for maximum smoothness
404
- cv::blur (tmpbuf, ctx.mroi , ctx.blur );
422
+ cv::blur (tmpbuf, ctx.mask_region , ctx.blur );
405
423
406
424
// copy out
407
425
mask = ctx.mask ;
0 commit comments