@@ -216,47 +216,67 @@ def get_op_set(op_test_config, ops):
216216 help = "the version of the test case, (defaults: v1.0)" ,
217217 )
218218 parser .add_argument ("--ops" , nargs = "+" , type = str , help = "An array of ops" )
219+ parser .add_argument (
220+ "--models" , type = str , default = None , help = "Specify the model to test."
221+ )
219222 args = parser .parse_args ()
220223
221224 # load config
222225 config = toml .load (args .config )
223- op_test_config = config ["ops_test" ]
224226
225- # generate test cases
226- op_set = get_op_set (op_test_config , args .ops )
227+ if not args .models :
228+ op_test_config = config ["ops_test" ]
229+
230+ # generate test cases
231+ op_set = get_op_set (op_test_config , args .ops )
227232
228- for op_type in op_set :
229- pkg = importlib .import_module (op_test_config [op_type ]["package" ])
230- op_configs = op_test_config [op_type ]["cfg" ]
231- op_test_func = op_test_config [op_type ]["test_func" ]
232- quant_bits = op_test_config [op_type ].get ("quant_bits" , [])
233+ for op_type in op_set :
234+ pkg = importlib .import_module (op_test_config [op_type ]["package" ])
235+ op_configs = op_test_config [op_type ]["cfg" ]
236+ op_test_func = op_test_config [op_type ]["test_func" ]
237+ quant_bits = op_test_config [op_type ].get ("quant_bits" , [])
233238
234- if (args .bits == 8 and "int8" in quant_bits ) or (
235- args .bits == 16 and "int16" in quant_bits
236- ):
237- export_path = os .path .join (args .output_path , op_type )
238- for cfg in op_configs :
239+ if (args .bits == 8 and "int8" in quant_bits ) or (
240+ args .bits == 16 and "int16" in quant_bits
241+ ):
242+ export_path = os .path .join (args .output_path , op_type )
243+ for cfg in op_configs :
244+ print (
245+ "Op Test Function: " ,
246+ op_test_func ,
247+ "Configs: " ,
248+ cfg ,
249+ "Package: " ,
250+ pkg .__name__ ,
251+ "Output Path: " ,
252+ export_path ,
253+ )
254+ op = getattr (pkg , op_test_func )(cfg )
255+ BaseInferencer (
256+ op ,
257+ export_path = export_path ,
258+ model_cfg = cfg ,
259+ target = args .target ,
260+ num_of_bits = args .bits ,
261+ model_version = args .version ,
262+ meta_cfg = config ["meta" ],
263+ )()
264+ else :
239265 print (
240- "Op Test Function: " ,
241- op_test_func ,
242- "Configs: " ,
243- cfg ,
244- "Package: " ,
245- pkg .__name__ ,
246- "Output Path: " ,
247- export_path ,
266+ f"Skip op: { op_type } , do not support quantization with { args .bits } bits."
248267 )
249- op = getattr (pkg , op_test_func )(cfg )
250- BaseInferencer (
251- op ,
252- export_path = export_path ,
253- model_cfg = cfg ,
254- target = args .target ,
255- num_of_bits = args .bits ,
256- model_version = args .version ,
257- meta_cfg = config ["meta" ],
258- )()
268+ else :
269+ model_config = config ["models_test" ][args .models ]
270+ if args .bits == 8 or args .bits == 16 :
271+ model = onnx .load (model_config ["onnx_model_path" ])
272+ BaseInferencer (
273+ model ,
274+ export_path = args .output_path ,
275+ model_cfg = model_config ,
276+ target = args .target ,
277+ num_of_bits = args .bits ,
278+ model_version = args .version ,
279+ meta_cfg = config ["meta" ],
280+ )()
259281 else :
260- print (
261- f"Skip op: { op_type } , do not support quantization with { args .bits } bits."
262- )
282+ print (f"Do not support quantization with { args .bits } bits." )
0 commit comments