Skip to content

Commit 3070ac0

Browse files
committed
use policy replacement in benchmark
1 parent 472e4f8 commit 3070ac0

File tree

1 file changed

+62
-48
lines changed

1 file changed

+62
-48
lines changed

crates/cervo-cli/src/commands/benchmark.rs

Lines changed: 62 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -226,48 +226,74 @@ pub fn build_inputs_from_desc(
226226

227227
fn do_run(
228228
wrapper: impl ModelWrapper,
229-
inferer: impl Inferer,
230-
batch_size: usize,
229+
inferer: impl Inferer + 'static,
231230
config: &Args,
232-
) -> Result<Record> {
233-
let mut model = Model::new(wrapper, inferer);
231+
) -> Result<Vec<Record>> {
232+
let mut model = Model::new(wrapper, Box::new(inferer) as Box<dyn Inferer>);
234233

235-
let shapes = model.input_shapes().to_vec();
236-
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
237-
for id in 0..batch_size {
238-
model.begin_agent(id as u64);
239-
}
240-
let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?;
241-
for id in 0..batch_size {
242-
model.end_agent(id as u64);
234+
let mut records = Vec::with_capacity(config.batch_sizes.len());
235+
for batch_size in config.batch_sizes.clone() {
236+
let mut reader = File::open(&config.file)?;
237+
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
238+
cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
239+
} else {
240+
match config.file.extension().and_then(|ext| ext.to_str()) {
241+
Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
242+
Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
243+
Some(other) => bail!("unknown file type {:?}", other),
244+
None => bail!("missing file extension {:?}", config.file),
245+
}
246+
};
247+
248+
model = model
249+
.with_new_policy(Box::new(inferer) as Box<dyn Inferer>)
250+
.map_err(|(_, e)| e)?;
251+
252+
let shapes = model.input_shapes().to_vec();
253+
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
254+
for id in 0..batch_size {
255+
model.begin_agent(id as u64);
256+
}
257+
let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?;
258+
259+
// Print Text
260+
if matches!(config.output, OutputFormat::Text) {
261+
println!(
262+
"Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
263+
res.batch_size, res.mean, res.stddev, res.total,
264+
);
265+
}
266+
267+
records.push(res);
268+
for id in 0..batch_size {
269+
model.end_agent(id as u64);
270+
}
243271
}
244272

245-
Ok(res)
273+
Ok(records)
246274
}
247275

248276
fn run_apply_epsilon_config(
249277
wrapper: impl ModelWrapper,
250-
inferer: impl Inferer,
251-
batch_size: usize,
278+
inferer: impl Inferer + 'static,
252279
config: &Args,
253-
) -> Result<Record> {
280+
) -> Result<Vec<Record>> {
254281
if let Some(epsilon) = config.with_epsilon.as_ref() {
255282
let wrapper = EpsilonInjectorWrapper::wrap(wrapper, &inferer, epsilon)?;
256-
do_run(wrapper, inferer, batch_size, config)
283+
do_run(wrapper, inferer, config)
257284
} else {
258-
do_run(wrapper, inferer, batch_size, config)
285+
do_run(wrapper, inferer, config)
259286
}
260287
}
261288

262289
fn run_apply_recurrent(
263290
wrapper: impl ModelWrapper,
264-
inferer: impl Inferer,
265-
batch_size: usize,
291+
inferer: impl Inferer + 'static,
266292
config: &Args,
267-
) -> Result<Record> {
293+
) -> Result<Vec<Record>> {
268294
if let Some(recurrent) = config.recurrent.as_ref() {
269295
if matches!(recurrent, RecurrentConfig::None) {
270-
run_apply_epsilon_config(wrapper, inferer, batch_size, config)
296+
run_apply_epsilon_config(wrapper, inferer, config)
271297
} else {
272298
let wrapper = match recurrent {
273299
RecurrentConfig::None => unreachable!(),
@@ -282,40 +308,28 @@ fn run_apply_recurrent(
282308
}
283309
}?;
284310

285-
run_apply_epsilon_config(wrapper, inferer, batch_size, config)
311+
run_apply_epsilon_config(wrapper, inferer, config)
286312
}
287313
} else {
288-
run_apply_epsilon_config(wrapper, inferer, batch_size, config)
314+
run_apply_epsilon_config(wrapper, inferer, config)
289315
}
290316
}
291317

292318
pub(super) fn run(config: Args) -> Result<()> {
293-
let mut records: Vec<Record> = Vec::new();
294-
for batch_size in config.batch_sizes.clone() {
295-
let mut reader = File::open(&config.file)?;
296-
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
297-
cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
298-
} else {
299-
match config.file.extension().and_then(|ext| ext.to_str()) {
300-
Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
301-
Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
302-
Some(other) => bail!("unknown file type {:?}", other),
303-
None => bail!("missing file extension {:?}", config.file),
304-
}
305-
};
306-
307-
let record = run_apply_recurrent(BaseCase, inferer, batch_size, &config)?;
308-
309-
// Print Text
310-
if matches!(config.output, OutputFormat::Text) {
311-
println!(
312-
"Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
313-
record.batch_size, record.mean, record.stddev, record.total,
314-
);
319+
let mut reader = File::open(&config.file)?;
320+
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
321+
cervo::nnef::builder(&mut reader).build_basic()?
322+
} else {
323+
match config.file.extension().and_then(|ext| ext.to_str()) {
324+
Some("onnx") => cervo::onnx::builder(&mut reader).build_basic()?,
325+
Some("crvo") => AssetData::deserialize(&mut reader)?.load_basic()?,
326+
Some(other) => bail!("unknown file type {:?}", other),
327+
None => bail!("missing file extension {:?}", config.file),
315328
}
329+
};
330+
331+
let records = run_apply_recurrent(BaseCase, inferer, &config)?;
316332

317-
records.push(record);
318-
}
319333
// Print JSON
320334
if matches!(config.output, OutputFormat::Json) {
321335
let json = serde_json::to_string_pretty(&records)?;

0 commit comments

Comments
 (0)