@@ -226,48 +226,74 @@ pub fn build_inputs_from_desc(
226226
227227fn do_run (
228228 wrapper : impl ModelWrapper ,
229- inferer : impl Inferer ,
230- batch_size : usize ,
229+ inferer : impl Inferer + ' static ,
231230 config : & Args ,
232- ) -> Result < Record > {
233- let mut model = Model :: new ( wrapper, inferer) ;
231+ ) -> Result < Vec < Record > > {
232+ let mut model = Model :: new ( wrapper, Box :: new ( inferer) as Box < dyn Inferer > ) ;
234233
235- let shapes = model. input_shapes ( ) . to_vec ( ) ;
236- let observations = build_inputs_from_desc ( batch_size as u64 , & shapes) ;
237- for id in 0 ..batch_size {
238- model. begin_agent ( id as u64 ) ;
239- }
240- let res = execute_load_metrics ( batch_size, observations, config. count , & mut model) ?;
241- for id in 0 ..batch_size {
242- model. end_agent ( id as u64 ) ;
234+ let mut records = Vec :: with_capacity ( config. batch_sizes . len ( ) ) ;
235+ for batch_size in config. batch_sizes . clone ( ) {
236+ let mut reader = File :: open ( & config. file ) ?;
237+ let inferer = if cervo:: nnef:: is_nnef_tar ( & config. file ) {
238+ cervo:: nnef:: builder ( & mut reader) . build_fixed ( & [ batch_size] ) ?
239+ } else {
240+ match config. file . extension ( ) . and_then ( |ext| ext. to_str ( ) ) {
241+ Some ( "onnx" ) => cervo:: onnx:: builder ( & mut reader) . build_fixed ( & [ batch_size] ) ?,
242+ Some ( "crvo" ) => AssetData :: deserialize ( & mut reader) ?. load_fixed ( & [ batch_size] ) ?,
243+ Some ( other) => bail ! ( "unknown file type {:?}" , other) ,
244+ None => bail ! ( "missing file extension {:?}" , config. file) ,
245+ }
246+ } ;
247+
248+ model = model
249+ . with_new_policy ( Box :: new ( inferer) as Box < dyn Inferer > )
250+ . map_err ( |( _, e) | e) ?;
251+
252+ let shapes = model. input_shapes ( ) . to_vec ( ) ;
253+ let observations = build_inputs_from_desc ( batch_size as u64 , & shapes) ;
254+ for id in 0 ..batch_size {
255+ model. begin_agent ( id as u64 ) ;
256+ }
257+ let res = execute_load_metrics ( batch_size, observations, config. count , & mut model) ?;
258+
259+ // Print Text
260+ if matches ! ( config. output, OutputFormat :: Text ) {
261+ println ! (
262+ "Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total" ,
263+ res. batch_size, res. mean, res. stddev, res. total,
264+ ) ;
265+ }
266+
267+ records. push ( res) ;
268+ for id in 0 ..batch_size {
269+ model. end_agent ( id as u64 ) ;
270+ }
243271 }
244272
245- Ok ( res )
273+ Ok ( records )
246274}
247275
248276fn run_apply_epsilon_config (
249277 wrapper : impl ModelWrapper ,
250- inferer : impl Inferer ,
251- batch_size : usize ,
278+ inferer : impl Inferer + ' static ,
252279 config : & Args ,
253- ) -> Result < Record > {
280+ ) -> Result < Vec < Record > > {
254281 if let Some ( epsilon) = config. with_epsilon . as_ref ( ) {
255282 let wrapper = EpsilonInjectorWrapper :: wrap ( wrapper, & inferer, epsilon) ?;
256- do_run ( wrapper, inferer, batch_size , config)
283+ do_run ( wrapper, inferer, config)
257284 } else {
258- do_run ( wrapper, inferer, batch_size , config)
285+ do_run ( wrapper, inferer, config)
259286 }
260287}
261288
262289fn run_apply_recurrent (
263290 wrapper : impl ModelWrapper ,
264- inferer : impl Inferer ,
265- batch_size : usize ,
291+ inferer : impl Inferer + ' static ,
266292 config : & Args ,
267- ) -> Result < Record > {
293+ ) -> Result < Vec < Record > > {
268294 if let Some ( recurrent) = config. recurrent . as_ref ( ) {
269295 if matches ! ( recurrent, RecurrentConfig :: None ) {
270- run_apply_epsilon_config ( wrapper, inferer, batch_size , config)
296+ run_apply_epsilon_config ( wrapper, inferer, config)
271297 } else {
272298 let wrapper = match recurrent {
273299 RecurrentConfig :: None => unreachable ! ( ) ,
@@ -282,40 +308,28 @@ fn run_apply_recurrent(
282308 }
283309 } ?;
284310
285- run_apply_epsilon_config ( wrapper, inferer, batch_size , config)
311+ run_apply_epsilon_config ( wrapper, inferer, config)
286312 }
287313 } else {
288- run_apply_epsilon_config ( wrapper, inferer, batch_size , config)
314+ run_apply_epsilon_config ( wrapper, inferer, config)
289315 }
290316}
291317
292318pub ( super ) fn run ( config : Args ) -> Result < ( ) > {
293- let mut records: Vec < Record > = Vec :: new ( ) ;
294- for batch_size in config. batch_sizes . clone ( ) {
295- let mut reader = File :: open ( & config. file ) ?;
296- let inferer = if cervo:: nnef:: is_nnef_tar ( & config. file ) {
297- cervo:: nnef:: builder ( & mut reader) . build_fixed ( & [ batch_size] ) ?
298- } else {
299- match config. file . extension ( ) . and_then ( |ext| ext. to_str ( ) ) {
300- Some ( "onnx" ) => cervo:: onnx:: builder ( & mut reader) . build_fixed ( & [ batch_size] ) ?,
301- Some ( "crvo" ) => AssetData :: deserialize ( & mut reader) ?. load_fixed ( & [ batch_size] ) ?,
302- Some ( other) => bail ! ( "unknown file type {:?}" , other) ,
303- None => bail ! ( "missing file extension {:?}" , config. file) ,
304- }
305- } ;
306-
307- let record = run_apply_recurrent ( BaseCase , inferer, batch_size, & config) ?;
308-
309- // Print Text
310- if matches ! ( config. output, OutputFormat :: Text ) {
311- println ! (
312- "Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total" ,
313- record. batch_size, record. mean, record. stddev, record. total,
314- ) ;
319+ let mut reader = File :: open ( & config. file ) ?;
320+ let inferer = if cervo:: nnef:: is_nnef_tar ( & config. file ) {
321+ cervo:: nnef:: builder ( & mut reader) . build_basic ( ) ?
322+ } else {
323+ match config. file . extension ( ) . and_then ( |ext| ext. to_str ( ) ) {
324+ Some ( "onnx" ) => cervo:: onnx:: builder ( & mut reader) . build_basic ( ) ?,
325+ Some ( "crvo" ) => AssetData :: deserialize ( & mut reader) ?. load_basic ( ) ?,
326+ Some ( other) => bail ! ( "unknown file type {:?}" , other) ,
327+ None => bail ! ( "missing file extension {:?}" , config. file) ,
315328 }
329+ } ;
330+
331+ let records = run_apply_recurrent ( BaseCase , inferer, & config) ?;
316332
317- records. push ( record) ;
318- }
319333 // Print JSON
320334 if matches ! ( config. output, OutputFormat :: Json ) {
321335 let json = serde_json:: to_string_pretty ( & records) ?;
0 commit comments