Skip to content

Commit 85b6eb1

Browse files
committed
fix: 增强错误处理
1 parent 7a6909a commit 85b6eb1

7 files changed

Lines changed: 73 additions & 71 deletions

File tree

examples/debug_det.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ fn main() {
4848

4949
// 预处理
5050
let params = NormalizeParams::paddle_det();
51-
let input = preprocess_for_det(&scaled, &params);
51+
let input = preprocess_for_det(&scaled, &params).expect("预处理失败");
5252
println!("输入张量形状: {:?}", input.shape());
5353

5454
// 推理

examples/debug_rec.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ fn main() {
4848
// 预处理
4949
let target_height = 48u32;
5050
let params = NormalizeParams::paddle_rec();
51-
let input = preprocess_for_rec(&image, target_height, &params);
51+
let input = preprocess_for_rec(&image, target_height, &params).expect("预处理失败");
5252
println!("输入张量形状: {:?}", input.shape());
5353

5454
// 推理

src/det.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ impl DetModel {
248248
let (scaled_width, scaled_height) = scaled.dimensions();
249249

250250
// Preprocess
251-
let input = preprocess_for_det(&scaled, &self.normalize_params);
251+
let input = preprocess_for_det(&scaled, &self.normalize_params)?;
252252

253253
// Inference (using dynamic shape)
254254
let output = self.engine.run_dynamic(input.view().into_dyn())?;

src/engine.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ impl OcrEngine {
405405
}
406406

407407
// 2. Batch recognition
408-
let (mut images, boxes): (Vec<DynamicImage>, Vec<TextBox>) = detections.into_iter().unzip();
408+
let (images, boxes): (Vec<DynamicImage>, Vec<TextBox>) = detections.into_iter().unzip();
409409

410410
let rec_results = if self.config.enable_parallel && images.len() > 4 {
411411
// Parallel recognition: for multiple text regions, use rayon for parallel processing

src/ori.rs

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ impl OriModel {
240240
let (class_idx, &confidence) = scores
241241
.iter()
242242
.enumerate()
243-
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
243+
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
244244
.ok_or_else(|| {
245245
OcrError::PostprocessError(
246246
"Orientation model output has no valid scores".to_string(),
@@ -255,7 +255,10 @@ impl OriModel {
255255
/// Convert class index to angle in degrees (best effort mapping)
256256
fn class_to_angle(num_classes: usize, class_idx: usize, class_angles: &[i32]) -> i32 {
257257
if class_angles.len() == num_classes {
258-
return class_angles.get(class_idx).copied().unwrap_or(class_idx as i32);
258+
return class_angles
259+
.get(class_idx)
260+
.copied()
261+
.unwrap_or(class_idx as i32);
259262
}
260263

261264
match num_classes {
@@ -282,10 +285,7 @@ fn softmax(scores: &[f32]) -> Vec<f32> {
282285
return Vec::new();
283286
}
284287

285-
let max_score = scores
286-
.iter()
287-
.cloned()
288-
.fold(f32::NEG_INFINITY, f32::max);
288+
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
289289
let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
290290
let sum_exp: f32 = exp_scores.iter().sum();
291291

@@ -342,11 +342,7 @@ fn preprocess_for_ori(
342342
let scale = resize_shorter as f32 / shorter;
343343
let new_w = (w as f32 * scale).round().max(1.0) as u32;
344344
let new_h = (h as f32 * scale).round().max(1.0) as u32;
345-
let resized = img.resize_exact(
346-
new_w,
347-
new_h,
348-
image::imageops::FilterType::Lanczos3,
349-
);
345+
let resized = img.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3);
350346

351347
if new_w < target_width || new_h < target_height {
352348
resized.resize_exact(
@@ -365,12 +361,7 @@ fn preprocess_for_ori(
365361
let rgb_img = processed.to_rgb8();
366362
let (proc_w, proc_h) = processed.dimensions();
367363

368-
let mut input = Array4::<f32>::zeros((
369-
1,
370-
3,
371-
target_height as usize,
372-
target_width as usize,
373-
));
364+
let mut input = Array4::<f32>::zeros((1, 3, target_height as usize, target_width as usize));
374365

375366
let max_y = proc_h.min(target_height) as usize;
376367
let max_x = proc_w.min(target_width) as usize;
@@ -458,15 +449,8 @@ mod tests {
458449
fn test_preprocess_for_ori_shape() {
459450
let img = DynamicImage::new_rgb8(100, 32);
460451
let params = NormalizeParams::paddle_det();
461-
let tensor = preprocess_for_ori(
462-
&img,
463-
224,
464-
224,
465-
256,
466-
OriPreprocessMode::Doc,
467-
&params,
468-
)
469-
.unwrap();
452+
let tensor =
453+
preprocess_for_ori(&img, 224, 224, 256, OriPreprocessMode::Doc, &params).unwrap();
470454
assert_eq!(tensor.shape(), &[1, 3, 224, 224]);
471455
}
472456
}

src/preprocess.rs

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
use image::{DynamicImage, GenericImageView, RgbImage};
66
use ndarray::{Array4, ArrayBase, Dim, OwnedRepr};
77

8+
use crate::error::{OcrError, OcrResult};
9+
810
/// Image normalization parameters
911
#[derive(Debug, Clone)]
1012
pub struct NormalizeParams {
@@ -51,12 +53,12 @@ pub fn get_padded_size(size: u32) -> u32 {
5153
/// Scale image to specified maximum side length
5254
///
5355
/// Maintains aspect ratio, scales longest side to max_side_len
54-
pub fn resize_to_max_side(img: &DynamicImage, max_side_len: u32) -> DynamicImage {
56+
pub fn resize_to_max_side(img: &DynamicImage, max_side_len: u32) -> OcrResult<DynamicImage> {
5557
let (w, h) = img.dimensions();
5658
let max_dim = w.max(h);
5759

5860
if max_dim <= max_side_len {
59-
return img.clone();
61+
return Ok(img.clone());
6062
}
6163

6264
let scale = max_side_len as f64 / max_dim as f64;
@@ -69,11 +71,11 @@ pub fn resize_to_max_side(img: &DynamicImage, max_side_len: u32) -> DynamicImage
6971
/// Scale image to specified height (for recognition model)
7072
///
7173
/// Scales maintaining aspect ratio
72-
pub fn resize_to_height(img: &DynamicImage, target_height: u32) -> DynamicImage {
74+
pub fn resize_to_height(img: &DynamicImage, target_height: u32) -> OcrResult<DynamicImage> {
7375
let (w, h) = img.dimensions();
7476

7577
if h == target_height {
76-
return img.clone();
78+
return Ok(img.clone());
7779
}
7880

7981
let scale = target_height as f64 / h as f64;
@@ -84,31 +86,44 @@ pub fn resize_to_height(img: &DynamicImage, target_height: u32) -> DynamicImage
8486

8587
/// Fast image resizing using fast_image_resize
8688
/// Can pass DynamicImage directly when "image" feature is enabled
87-
fn fast_resize(img: &DynamicImage, new_w: u32, new_h: u32) -> DynamicImage {
89+
fn fast_resize(img: &DynamicImage, new_w: u32, new_h: u32) -> OcrResult<DynamicImage> {
8890
use fast_image_resize::{images::Image, IntoImageView, PixelType, Resizer};
8991

90-
// Get source image pixel type
91-
let pixel_type = img.pixel_type().unwrap_or(PixelType::U8x3);
92+
// Only U8x3 (RGB) and U8x4 (RGBA) are handled end-to-end.
93+
// Grayscale (U8), 16-bit, and other formats must be converted to RGB first;
94+
// otherwise the output buffer byte count would not match the expected channel count.
95+
let converted: DynamicImage;
96+
let (src, pixel_type) = match img.pixel_type() {
97+
Some(PixelType::U8x3) => (img, PixelType::U8x3),
98+
Some(PixelType::U8x4) => (img, PixelType::U8x4),
99+
_ => {
100+
converted = DynamicImage::ImageRgb8(img.to_rgb8());
101+
(&converted, PixelType::U8x3)
102+
}
103+
};
92104

93105
// Create destination image container
94106
let mut dst_image = Image::new(new_w, new_h, pixel_type);
95107

96-
// Resize using Resizer (pass DynamicImage directly, no manual conversion needed)
108+
// Resize using Resizer
97109
let mut resizer = Resizer::new();
98-
resizer.resize(img, &mut dst_image, None).unwrap();
110+
resizer
111+
.resize(src, &mut dst_image, None)
112+
.map_err(|e| OcrError::PreprocessError(format!("Image resize failed: {e}")))?;
99113

100114
// Convert result back to DynamicImage
101115
match pixel_type {
102-
PixelType::U8x3 => {
103-
DynamicImage::ImageRgb8(RgbImage::from_raw(new_w, new_h, dst_image.into_vec()).unwrap())
104-
}
105-
PixelType::U8x4 => DynamicImage::ImageRgba8(
106-
image::RgbaImage::from_raw(new_w, new_h, dst_image.into_vec()).unwrap(),
107-
),
108-
_ => {
109-
// Convert other types to RGB
110-
DynamicImage::ImageRgb8(RgbImage::from_raw(new_w, new_h, dst_image.into_vec()).unwrap())
111-
}
116+
PixelType::U8x3 => RgbImage::from_raw(new_w, new_h, dst_image.into_vec())
117+
.map(DynamicImage::ImageRgb8)
118+
.ok_or_else(|| {
119+
OcrError::PreprocessError("RGB buffer size mismatch after resize".into())
120+
}),
121+
PixelType::U8x4 => image::RgbaImage::from_raw(new_w, new_h, dst_image.into_vec())
122+
.map(DynamicImage::ImageRgba8)
123+
.ok_or_else(|| {
124+
OcrError::PreprocessError("RGBA buffer size mismatch after resize".into())
125+
}),
126+
_ => unreachable!("pixel_type is constrained to U8x3 or U8x4 above"),
112127
}
113128
}
114129

@@ -118,7 +133,7 @@ fn fast_resize(img: &DynamicImage, new_w: u32, new_h: u32) -> DynamicImage {
118133
pub fn preprocess_for_det(
119134
img: &DynamicImage,
120135
params: &NormalizeParams,
121-
) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> {
136+
) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
122137
let (w, h) = img.dimensions();
123138
let pad_w = get_padded_size(w) as usize;
124139
let pad_h = get_padded_size(h) as usize;
@@ -138,7 +153,7 @@ pub fn preprocess_for_det(
138153
}
139154
}
140155

141-
input
156+
Ok(input)
142157
}
143158

144159
/// Convert image to recognition model input tensor
@@ -149,7 +164,7 @@ pub fn preprocess_for_rec(
149164
img: &DynamicImage,
150165
target_height: u32,
151166
params: &NormalizeParams,
152-
) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> {
167+
) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
153168
let (w, h) = img.dimensions();
154169

155170
// Calculate scaled width
@@ -183,7 +198,7 @@ pub fn preprocess_for_rec(
183198
}
184199
}
185200

186-
input
201+
Ok(input)
187202
}
188203

189204
/// Batch preprocess recognition images
@@ -193,9 +208,9 @@ pub fn preprocess_batch_for_rec(
193208
images: &[DynamicImage],
194209
target_height: u32,
195210
params: &NormalizeParams,
196-
) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> {
211+
) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
197212
if images.is_empty() {
198-
return Array4::<f32>::zeros((0, 3, target_height as usize, 0));
213+
return Ok(Array4::<f32>::zeros((0, 3, target_height as usize, 0)));
199214
}
200215

201216
// Calculate scaled width for all images
@@ -208,13 +223,14 @@ pub fn preprocess_batch_for_rec(
208223
})
209224
.collect();
210225

226+
// widths is non-empty because images is non-empty (checked above)
211227
let max_width = *widths.iter().max().unwrap() as usize;
212228
let batch_size = images.len();
213229

214230
let mut batch = Array4::<f32>::zeros((batch_size, 3, target_height as usize, max_width));
215231

216232
for (i, (img, &w)) in images.iter().zip(widths.iter()).enumerate() {
217-
let resized = resize_to_height(img, target_height);
233+
let resized = resize_to_height(img, target_height)?;
218234
let rgb_img = resized.to_rgb8();
219235

220236
for y in 0..target_height as usize {
@@ -229,7 +245,7 @@ pub fn preprocess_batch_for_rec(
229245
}
230246
}
231247

232-
batch
248+
Ok(batch)
233249
}
234250

235251
/// Crop image region
@@ -346,7 +362,7 @@ mod tests {
346362
#[test]
347363
fn test_resize_to_max_side_no_resize() {
348364
let img = DynamicImage::new_rgb8(100, 50);
349-
let resized = resize_to_max_side(&img, 200);
365+
let resized = resize_to_max_side(&img, 200).unwrap();
350366

351367
// 图像已经小于最大边,不应该缩放
352368
assert_eq!(resized.width(), 100);
@@ -356,7 +372,7 @@ mod tests {
356372
#[test]
357373
fn test_resize_to_max_side_width_limited() {
358374
let img = DynamicImage::new_rgb8(1000, 500);
359-
let resized = resize_to_max_side(&img, 500);
375+
let resized = resize_to_max_side(&img, 500).unwrap();
360376

361377
// 宽度是最大边,应该缩放到 500
362378
assert_eq!(resized.width(), 500);
@@ -366,7 +382,7 @@ mod tests {
366382
#[test]
367383
fn test_resize_to_max_side_height_limited() {
368384
let img = DynamicImage::new_rgb8(500, 1000);
369-
let resized = resize_to_max_side(&img, 500);
385+
let resized = resize_to_max_side(&img, 500).unwrap();
370386

371387
// 高度是最大边,应该缩放到 500
372388
assert_eq!(resized.width(), 250);
@@ -376,7 +392,7 @@ mod tests {
376392
#[test]
377393
fn test_resize_to_height() {
378394
let img = DynamicImage::new_rgb8(200, 100);
379-
let resized = resize_to_height(&img, 48);
395+
let resized = resize_to_height(&img, 48).unwrap();
380396

381397
assert_eq!(resized.height(), 48);
382398
// 宽度应该按比例缩放: 200 * 48/100 = 96
@@ -386,7 +402,7 @@ mod tests {
386402
#[test]
387403
fn test_resize_to_height_no_resize() {
388404
let img = DynamicImage::new_rgb8(200, 48);
389-
let resized = resize_to_height(&img, 48);
405+
let resized = resize_to_height(&img, 48).unwrap();
390406

391407
// 高度已经是目标高度,不应该缩放
392408
assert_eq!(resized.height(), 48);
@@ -397,7 +413,7 @@ mod tests {
397413
fn test_preprocess_for_det_shape() {
398414
let img = DynamicImage::new_rgb8(100, 50);
399415
let params = NormalizeParams::paddle_det();
400-
let tensor = preprocess_for_det(&img, &params);
416+
let tensor = preprocess_for_det(&img, &params).unwrap();
401417

402418
// 输出形状应该是 [1, 3, H, W],H 和 W 是 32 的倍数
403419
assert_eq!(tensor.shape()[0], 1);
@@ -410,7 +426,7 @@ mod tests {
410426
fn test_preprocess_for_rec_shape() {
411427
let img = DynamicImage::new_rgb8(200, 100);
412428
let params = NormalizeParams::paddle_rec();
413-
let tensor = preprocess_for_rec(&img, 48, &params);
429+
let tensor = preprocess_for_rec(&img, 48, &params).unwrap();
414430

415431
// 输出高度应该是 48
416432
assert_eq!(tensor.shape()[0], 1);
@@ -424,7 +440,7 @@ mod tests {
424440
fn test_preprocess_batch_for_rec_empty() {
425441
let images: Vec<DynamicImage> = vec![];
426442
let params = NormalizeParams::paddle_rec();
427-
let tensor = preprocess_batch_for_rec(&images, 48, &params);
443+
let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();
428444

429445
assert_eq!(tensor.shape()[0], 0);
430446
}
@@ -433,7 +449,7 @@ mod tests {
433449
fn test_preprocess_batch_for_rec_single() {
434450
let images = vec![DynamicImage::new_rgb8(200, 100)];
435451
let params = NormalizeParams::paddle_rec();
436-
let tensor = preprocess_batch_for_rec(&images, 48, &params);
452+
let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();
437453

438454
assert_eq!(tensor.shape()[0], 1);
439455
assert_eq!(tensor.shape()[1], 3);
@@ -447,7 +463,7 @@ mod tests {
447463
DynamicImage::new_rgb8(300, 100),
448464
];
449465
let params = NormalizeParams::paddle_rec();
450-
let tensor = preprocess_batch_for_rec(&images, 48, &params);
466+
let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();
451467

452468
assert_eq!(tensor.shape()[0], 2);
453469
assert_eq!(tensor.shape()[1], 3);

0 commit comments

Comments
 (0)