Skip to content

Commit 9f984b6

Browse files
authored
Merge pull request #205 from robertknight/model-interface
Refactor tests and `RunOptions` usage for compatibility with upcoming RTen release
2 parents b47516f + 78e55da commit 9f984b6

File tree

5 files changed: +205 additions, -190 deletions

ocrs/src/detection.rs

Lines changed: 12 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
11
use anyhow::anyhow;
2-
use rten::{Dimension, FloatOperators, Model, Operators, RunOptions};
2+
use rten::{Dimension, FloatOperators, Operators, RunOptions};
33
use rten_imageproc::{find_contours, min_area_rect, simplify_polygon, RetrievalMode, RotatedRect};
44
use rten_tensor::prelude::*;
55
use rten_tensor::{NdTensor, NdTensorView, Tensor};
66

7+
use crate::model::Model;
78
use crate::preprocess::BLACK_VALUE;
89

910
/// Parameters that control post-processing of text detection model outputs.
@@ -63,7 +64,7 @@ fn find_connected_component_rects(
6364
/// Text detector which finds the oriented bounding boxes of words in an input
6465
/// image.
6566
pub struct TextDetector {
66-
model: Model,
67+
model: Box<dyn Model>,
6768
params: TextDetectorParams,
6869
input_shape: Vec<Dimension>,
6970
}
@@ -72,19 +73,13 @@ impl TextDetector {
7273
/// Initialize a DetectionModel from a trained RTen model.
7374
///
7475
/// This will fail if the model doesn't have the expected inputs or outputs.
75-
pub fn from_model(model: Model, params: TextDetectorParams) -> anyhow::Result<TextDetector> {
76-
let input_id = model
77-
.input_ids()
78-
.first()
79-
.copied()
80-
.ok_or(anyhow!("model has no inputs"))?;
81-
let input_shape = model
82-
.node_info(input_id)
83-
.and_then(|info| info.shape())
84-
.ok_or(anyhow!("model does not specify expected input shape"))?;
85-
76+
pub fn from_model(
77+
model: impl Model + 'static,
78+
params: TextDetectorParams,
79+
) -> anyhow::Result<TextDetector> {
80+
let input_shape = model.input_shape()?;
8681
Ok(TextDetector {
87-
model,
82+
model: Box::new(model),
8883
params,
8984
input_shape,
9085
})
@@ -177,21 +172,9 @@ impl TextDetector {
177172

178173
// Run text detection model to compute a probability mask indicating whether
179174
// each pixel is part of a text word or not.
180-
let text_mask: Tensor<f32> = self
181-
.model
182-
.run_one(
183-
image.view().into(),
184-
if debug {
185-
Some(RunOptions {
186-
timing: true,
187-
verbose: false,
188-
..Default::default()
189-
})
190-
} else {
191-
None
192-
},
193-
)?
194-
.try_into()?;
175+
let mut opts = RunOptions::default();
176+
opts.timing = debug;
177+
let text_mask: Tensor<f32> = self.model.run(image.view(), Some(opts))?;
195178

196179
// Resize probability mask to original input size and apply threshold to get a
197180
// binary text/not-text mask.

0 commit comments

Comments
 (0)