diff --git a/rust/cli/CHANGELOG.md b/rust/cli/CHANGELOG.md index a198e08e..b3af0b90 100644 --- a/rust/cli/CHANGELOG.md +++ b/rust/cli/CHANGELOG.md @@ -4,6 +4,7 @@ ### Patch +- Add internal `_time` feature for inference time measurement - Enable full LTO for the release profile ## 1.0.1 diff --git a/rust/cli/Cargo.toml b/rust/cli/Cargo.toml index 63a2e301..e7f553c6 100644 --- a/rust/cli/Cargo.toml +++ b/rust/cli/Cargo.toml @@ -23,6 +23,10 @@ serde = { version = "1.0.204", features = ["derive"] } serde_json = "1.0.120" tokio = { version = "1.43.1", features = ["full"] } +[features] +# Internal feature to measure inference time. +_time = [] + [profile.release] codegen-units = 1 lto = true diff --git a/rust/cli/src/main.rs b/rust/cli/src/main.rs index 6f5c64bc..c21e2ffd 100644 --- a/rust/cli/src/main.rs +++ b/rust/cli/src/main.rs @@ -30,6 +30,8 @@ use ort::session::builder::GraphOptimizationLevel; use serde::Serialize; use tokio::fs::File; use tokio::io::AsyncReadExt; +#[cfg(feature = "_time")] +use tokio::time::Instant; /// Determines file content types using AI. #[derive(Parser)] @@ -172,6 +174,8 @@ async fn main() -> Result<()> { std::cmp::max(2, num_cpus::get_physical()) }); ensure!(0 < num_tasks, "--num-tasks cannot be zero"); + #[cfg(feature = "_time")] + ensure!(num_tasks == 1, "--num-tasks must be 1 for time measurements"); ensure!( flags.path.iter().filter(|x| x.to_str() == Some("-")).count() <= 1, "only one path can be the standard input" @@ -213,10 +217,14 @@ async fn main() -> Result<()> { } let mut reorder = Reorder::default(); let mut errors = false; + #[cfg(feature = "_time")] + let mut time = (0., 0.); while let Some(response) = result_receiver.recv().await { reorder.push(response?); while let Some(response) = reorder.pop() { errors |= response.result.is_err(); + #[cfg(feature = "_time")] + (time = (time.0 + response.time, time.1 + 1.)); if flags.format.json { if reorder.next != 1 { print!(","); @@ -236,6 +244,10 @@ async fn main() -> Result<()> { } println!("]"); } + #[cfg(feature = "_time")] + if 0. < time.1 { + println!("Average inference time: {:.2}ms", 1000. * time.0 / time.1); + } if errors { std::process::exit(1); } @@ -260,7 +272,17 @@ async fn extract_features( Ok(ProcessPath::Features(x)) => features.push(x), }; match result { - Some(result) => result_sender.send(Ok(Response { order, path, result })).await?, + Some(result) => { + result_sender + .send(Ok(Response { + order, + path, + result, + #[cfg(feature = "_time")] + time: 0., + })) + .await? + } None => paths.push((order, path)), } order += 1; @@ -360,11 +382,23 @@ async fn infer_batch( sender: &tokio::sync::mpsc::Sender>, ) -> Result<()> { while let Ok(Batch { paths, features }) = receiver.recv().await { + #[cfg(feature = "_time")] + let start = Instant::now(); let batch = magika.identify_features_batch_async(&features).await?; + #[cfg(feature = "_time")] + let time = Instant::now().duration_since(start).as_secs_f64() / paths.len() as f64; assert_eq!(batch.len(), paths.len()); for ((order, path), output) in paths.into_iter().zip(batch.into_iter()) { let result = Ok(output); - sender.send(Ok(Response { order, path, result })).await?; + sender + .send(Ok(Response { + order, + path, + result, + #[cfg(feature = "_time")] + time, + })) + .await?; } } Ok(()) @@ -404,6 +438,8 @@ struct Response { order: usize, path: PathBuf, result: Result, + #[cfg(feature = "_time")] + time: f64, } #[derive(Serialize)] @@ -483,6 +519,8 @@ impl Response { )?; } } + #[cfg(feature = "_time")] + Some('t') => write!(&mut result, "{:.2}ms", 1000. * self.time)?, Some(c) => result.push(c), None => break, }, diff --git a/rust/cli/test.sh b/rust/cli/test.sh index c583c86f..c1d71f4a 100755 --- a/rust/cli/test.sh +++ b/rust/cli/test.sh @@ -17,6 +17,7 @@ set -e . ../color.sh x cargo check +x cargo check --features=_time x cargo build --release x cargo fmt -- --check x cargo clippy -- --deny=warnings