Skip to content

Commit 9ae3d8b

Browse files
committed
Use a trait to abstract the inference engine and mock it in tests
In the next rten version some of the current APIs used to construct models are going to be removed (see robertknight/rten#1014 and in particular robertknight/rten#1018). Hence this project needs to use a different way to create the fake models for unit tests. This commit introduces a `Model` trait as an abstraction of the runtime. This is implemented by `rten::Model` for non-test usage and by fake model types in tests. The new fakes have the same behavior as the previous "real" model.
1 parent b47516f commit 9ae3d8b

File tree

5 files changed

+214
-190
lines changed

5 files changed

+214
-190
lines changed

ocrs/src/detection.rs

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use anyhow::anyhow;
2-
use rten::{Dimension, FloatOperators, Model, Operators, RunOptions};
2+
use rten::{Dimension, FloatOperators, Operators, RunOptions};
33
use rten_imageproc::{find_contours, min_area_rect, simplify_polygon, RetrievalMode, RotatedRect};
44
use rten_tensor::prelude::*;
55
use rten_tensor::{NdTensor, NdTensorView, Tensor};
66

7+
use crate::model::Model;
78
use crate::preprocess::BLACK_VALUE;
89

910
/// Parameters that control post-processing of text detection model outputs.
@@ -63,7 +64,7 @@ fn find_connected_component_rects(
6364
/// Text detector which finds the oriented bounding boxes of words in an input
6465
/// image.
6566
pub struct TextDetector {
66-
model: Model,
67+
model: Box<dyn Model>,
6768
params: TextDetectorParams,
6869
input_shape: Vec<Dimension>,
6970
}
@@ -72,19 +73,13 @@ impl TextDetector {
7273
/// Initialize a DetectionModel from a trained RTen model.
7374
///
7475
/// This will fail if the model doesn't have the expected inputs or outputs.
75-
pub fn from_model(model: Model, params: TextDetectorParams) -> anyhow::Result<TextDetector> {
76-
let input_id = model
77-
.input_ids()
78-
.first()
79-
.copied()
80-
.ok_or(anyhow!("model has no inputs"))?;
81-
let input_shape = model
82-
.node_info(input_id)
83-
.and_then(|info| info.shape())
84-
.ok_or(anyhow!("model does not specify expected input shape"))?;
85-
76+
pub fn from_model(
77+
model: impl Model + 'static,
78+
params: TextDetectorParams,
79+
) -> anyhow::Result<TextDetector> {
80+
let input_shape = model.input_shape()?;
8681
Ok(TextDetector {
87-
model,
82+
model: Box::new(model),
8883
params,
8984
input_shape,
9085
})
@@ -177,21 +172,18 @@ impl TextDetector {
177172

178173
// Run text detection model to compute a probability mask indicating whether
179174
// each pixel is part of a text word or not.
180-
let text_mask: Tensor<f32> = self
181-
.model
182-
.run_one(
183-
image.view().into(),
184-
if debug {
185-
Some(RunOptions {
186-
timing: true,
187-
verbose: false,
188-
..Default::default()
189-
})
190-
} else {
191-
None
192-
},
193-
)?
194-
.try_into()?;
175+
let text_mask: Tensor<f32> = self.model.run(
176+
image.view(),
177+
if debug {
178+
Some(RunOptions {
179+
timing: true,
180+
verbose: false,
181+
..Default::default()
182+
})
183+
} else {
184+
None
185+
},
186+
)?;
195187

196188
// Resize probability mask to original input size and apply threshold to get a
197189
// binary text/not-text mask.

0 commit comments

Comments
 (0)