Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/cli/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Patch

- Add internal `_time` feature for inference time measurement
- Enable full LTO for the release profile

## 1.0.1
Expand Down
4 changes: 4 additions & 0 deletions rust/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.43.1", features = ["full"] }

[features]
# Internal feature to measure inference time.
_time = []

[profile.release]
codegen-units = 1
lto = true
Expand Down
38 changes: 36 additions & 2 deletions rust/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use std::fmt::Write;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "_time")]
use std::time::Instant;

use anyhow::{bail, ensure, Result};
use clap::{Args, Parser};
Expand Down Expand Up @@ -213,10 +215,14 @@ async fn main() -> Result<()> {
}
let mut reorder = Reorder::default();
let mut errors = false;
#[cfg(feature = "_time")]
let mut time = (0., 0.);
while let Some(response) = result_receiver.recv().await {
reorder.push(response?);
while let Some(response) = reorder.pop() {
errors |= response.result.is_err();
#[cfg(feature = "_time")]
(time = (time.0 + response.time, time.1 + 1.));
if flags.format.json {
if reorder.next != 1 {
print!(",");
Expand All @@ -236,6 +242,8 @@ async fn main() -> Result<()> {
}
println!("]");
}
#[cfg(feature = "_time")]
println!("Average inference time: {:.2}ms", 1000. * time.0 / time.1);
if errors {
std::process::exit(1);
}
Expand All @@ -260,7 +268,17 @@ async fn extract_features(
Ok(ProcessPath::Features(x)) => features.push(x),
};
match result {
Some(result) => result_sender.send(Ok(Response { order, path, result })).await?,
Some(result) => {
result_sender
.send(Ok(Response {
order,
path,
result,
#[cfg(feature = "_time")]
time: 0.,
}))
.await?
}
None => paths.push((order, path)),
}
order += 1;
Expand Down Expand Up @@ -360,11 +378,23 @@ async fn infer_batch(
sender: &tokio::sync::mpsc::Sender<Result<Response>>,
) -> Result<()> {
while let Ok(Batch { paths, features }) = receiver.recv().await {
#[cfg(feature = "_time")]
let start = Instant::now();
let batch = magika.identify_features_batch_async(&features).await?;
#[cfg(feature = "_time")]
let time = Instant::now().duration_since(start).as_secs_f32();
assert_eq!(batch.len(), paths.len());
for ((order, path), output) in paths.into_iter().zip(batch.into_iter()) {
let result = Ok(output);
sender.send(Ok(Response { order, path, result })).await?;
sender
.send(Ok(Response {
order,
path,
result,
#[cfg(feature = "_time")]
time,
}))
.await?;
}
}
Ok(())
Expand Down Expand Up @@ -404,6 +434,8 @@ struct Response {
order: usize,
path: PathBuf,
result: Result<FileType, magika::Error>,
#[cfg(feature = "_time")]
time: f32,
}

#[derive(Serialize)]
Expand Down Expand Up @@ -483,6 +515,8 @@ impl Response {
)?;
}
}
#[cfg(feature = "_time")]
Some('t') => write!(&mut result, "{:.2}ms", 1000. * self.time)?,
Some(c) => result.push(c),
None => break,
},
Expand Down
1 change: 1 addition & 0 deletions rust/cli/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set -e
. ../color.sh

x cargo check
x cargo check --features=_time
x cargo build --release
x cargo fmt -- --check
x cargo clippy -- --deny=warnings
Expand Down
Loading