Skip to content

Commit e62ad4c

Browse files
committed
refactor: simplify review fixes for bat detection code
- Use `if let Some` guard instead of redundant ok_or_else after is_some() check - Use .nth() instead of collecting all keys into Vec in extract_tensor_data - Remove hardcoded ModelType::BirdNetV24 from CustomClassifier label loading; use LabelFormat::Text directly for model-agnostic behavior - Remove unused ModelType import
1 parent d497e31 commit e62ad4c

2 files changed

Lines changed: 13 additions & 12 deletions

File tree

src/classifier.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -992,12 +992,9 @@ impl Classifier {
992992
let num_species = self.inner.config.num_species;
993993

994994
match model_type {
995-
ModelType::BirdNetV24 if self.inner.config.embedding_dim.is_some() => {
996-
let embedding_dim = self.inner.config.embedding_dim.ok_or_else(|| {
997-
Error::Inference(
998-
"embedding_dim missing for v2.4 model with embeddings".into(),
999-
)
1000-
})?;
995+
ModelType::BirdNetV24
996+
if let Some(embedding_dim) = self.inner.config.embedding_dim =>
997+
{
1001998
let logits_flat = extract_tensor_data(outputs, 0)?;
1002999
let emb_flat = extract_tensor_data(outputs, 1)?;
10031000

src/custom_classifier.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::error::{Error, Result};
44
use crate::labels;
55
use crate::postprocess::top_k_predictions;
6-
use crate::types::{ModelType, Prediction};
6+
use crate::types::Prediction;
77
use ndarray::Array2;
88
use ort::session::Session;
99
use ort::value::Value;
@@ -86,7 +86,11 @@ impl CustomClassifierBuilder {
8686
let input_dim = extract_last_dim(session.inputs(), "input")?;
8787
let num_classes = extract_last_dim(session.outputs(), "output")?;
8888

89-
let labels = labels::load_labels_from_file(&labels_path, ModelType::BirdNetV24)?;
89+
let content = std::fs::read_to_string(&labels_path).map_err(|e| Error::LabelLoad {
90+
path: labels_path.display().to_string(),
91+
reason: e.to_string(),
92+
})?;
93+
let labels = labels::parse_labels(&content, crate::types::LabelFormat::Text)?;
9094

9195
if labels.len() != num_classes {
9296
return Err(Error::LabelCount {
@@ -275,13 +279,13 @@ fn extract_tensor_data(
275279
index: usize,
276280
expected_len: usize,
277281
) -> Result<Vec<f32>> {
278-
let output_names: Vec<_> = outputs.keys().collect();
279-
let name = output_names
280-
.get(index)
282+
let name = outputs
283+
.keys()
284+
.nth(index)
281285
.ok_or_else(|| Error::Inference(format!("missing output tensor at index {index}")))?;
282286

283287
let tensor = outputs
284-
.get(*name)
288+
.get(name)
285289
.ok_or_else(|| Error::Inference(format!("missing output tensor '{name}'")))?;
286290

287291
let (_, data) = tensor

0 commit comments

Comments
 (0)