2121 multi_return_metadata ,
2222 MultiReturn ,
2323 resnet18 ,
24+ resnet18_dynamo ,
2425 Simple ,
2526 )
2627except ImportError :
3031 multi_return_metadata ,
3132 MultiReturn ,
3233 resnet18 ,
34+ resnet18_dynamo ,
3335 Simple ,
3436 )
3537
@@ -60,34 +62,39 @@ def save(
6062 name ,
6163 model ,
6264 model_jit = None ,
65+ model_dynamo = None ,
6366 eg = None ,
6467 featurestore_meta = None ,
6568 text_in_extra_file = None ,
6669 binary_in_extra_file = None ,
6770):
68- with PackageExporter (str (p / name )) as e :
69- e .mock ("iopath.**" )
70- e .intern ("**" )
71- e .save_pickle ("model" , "model.pkl" , model )
72- if eg :
73- e .save_pickle ("model" , "example.pkl" , eg )
74- if featurestore_meta :
75- # TODO(whc) can this name come from buck somehow,
76- # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
77- e .save_text ("extra_files" , "metadata.json" , featurestore_meta )
78- if text_in_extra_file :
79- e .save_text ("extra_files" , "text" , text_in_extra_file )
80- if binary_in_extra_file :
81- e .save_binary ("extra_files" , "binary" , binary_in_extra_file )
82-
71+ def package_model (name , model ):
72+ with PackageExporter (str (p / name )) as e :
73+ e .mock ("iopath.**" )
74+ e .intern ("**" )
75+ e .save_pickle ("model" , "model.pkl" , model )
76+ if eg :
77+ e .save_pickle ("model" , "example.pkl" , eg )
78+ if featurestore_meta :
79+ # TODO(whc) can this name come from buck somehow,
80+ # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
81+ e .save_text ("extra_files" , "metadata.json" , featurestore_meta )
82+ if text_in_extra_file :
83+ e .save_text ("extra_files" , "text" , text_in_extra_file )
84+ if binary_in_extra_file :
85+ e .save_binary ("extra_files" , "binary" , binary_in_extra_file )
86+
87+ package_model (name , model )
88+ if model_dynamo :
89+ package_model (name + "_dynamo" , model_dynamo )
8390 if model_jit :
8491 model_jit .save (str (p / (name + "_jit" )))
92+
8593
8694
8795parser = argparse .ArgumentParser (description = "Generate Examples" )
8896parser .add_argument ("--install_dir" , help = "Root directory for all output files" )
8997
90-
9198if __name__ == "__main__" :
9299 args = parser .parse_args ()
93100 if args .install_dir is None :
@@ -98,9 +105,11 @@ def save(
98105
99106 resnet = resnet18 ()
100107 resnet .eval ()
108+ resnet_dynamo = resnet18_dynamo ()
101109 resnet_eg = torch .rand (1 , 3 , 224 , 224 )
102110 resnet_traced = torch .jit .trace (resnet , resnet_eg )
103- save ("resnet" , resnet , resnet_traced , (resnet_eg ,))
111+ save ("resnet" , resnet_dynamo , resnet_traced , resnet_dynamo , (resnet_eg ,))
112+ # save("resnet", resnet, resnet_traced, resnet_dynamo, (resnet_eg,))
104113
105114 simple = Simple (10 , 20 )
106115 save (
@@ -117,6 +126,7 @@ def save(
117126 "multi_return" ,
118127 multi_return ,
119128 torch .jit .script (multi_return ),
129+ None ,
120130 (torch .rand (10 , 20 ),),
121131 multi_return_metadata ,
122132 )
@@ -149,4 +159,4 @@ def save(
149159 e .add_dependency ("tensorrt" )
150160 e .mock ("iopath.**" )
151161 e .intern ("**" )
152- e .save_pickle ("make_trt_module" , "model.pkl" , make_trt_module )
162+ e .save_pickle ("make_trt_module" , "model.pkl" , make_trt_module )
0 commit comments