@@ -67,39 +67,42 @@ def postprocess(i):
6767
6868 env = i ['env' ]
6969
70- if os .path .exists (os .path .join (
71- env ['GPTJ_CHECKPOINT_PATH' ], "checkpoint-final" )):
72- env ['GPTJ_CHECKPOINT_PATH' ] = os .path .join (
73- env ['GPTJ_CHECKPOINT_PATH' ], "checkpoint-final" )
74-
75- is_saxml = env .get ('MLC_TMP_MODEL_SAXML' , '' )
76- if is_saxml == "fp32" :
77- if os .path .exists ("pax_gptj_checkpoint" ):
78- env ['GPTJ_SAXML_CHECKPOINT_PATH' ] = os .path .join (
79- os .getcwd (), "pax_gptj_checkpoint" )
80- env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = env ['GPTJ_SAXML_CHECKPOINT_PATH' ]
81- else :
82- return {'return' : 1 , 'error' : 'pax_gptj_checkpoint generation failed' }
70+ if not env .get ('MLC_DOWNLOAD_MODE' , '' ) == "dry" :
71+ if os .path .exists (os .path .join (
72+ env ['GPTJ_CHECKPOINT_PATH' ], "checkpoint-final" )):
73+ env ['GPTJ_CHECKPOINT_PATH' ] = os .path .join (
74+ env ['GPTJ_CHECKPOINT_PATH' ], "checkpoint-final" )
75+
76+ is_saxml = env .get ('MLC_TMP_MODEL_SAXML' , '' )
77+ if is_saxml == "fp32" :
78+ if os .path .exists ("pax_gptj_checkpoint" ):
79+ env ['GPTJ_SAXML_CHECKPOINT_PATH' ] = os .path .join (
80+ os .getcwd (), "pax_gptj_checkpoint" )
81+ env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = env ['GPTJ_SAXML_CHECKPOINT_PATH' ]
82+ else :
83+ return {'return' : 1 ,
84+ 'error' : 'pax_gptj_checkpoint generation failed' }
8385
84- elif is_saxml == "int8" :
85- if os .path .exists ("int8_ckpt" ):
86- env ['GPTJ_SAXML_INT8_CHECKPOINT_PATH' ] = os .path .join (
87- os .getcwd (), "int8_ckpt" )
88- env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = env ['GPTJ_SAXML_INT8_CHECKPOINT_PATH' ]
86+ elif is_saxml == "int8" :
87+ if os .path .exists ("int8_ckpt" ):
88+ env ['GPTJ_SAXML_INT8_CHECKPOINT_PATH' ] = os .path .join (
89+ os .getcwd (), "int8_ckpt" )
90+ env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = env ['GPTJ_SAXML_INT8_CHECKPOINT_PATH' ]
91+ else :
92+ return {'return' : 1 ,
93+ 'error' : 'pax_gptj_checkpoint generation failed' }
94+ elif env .get ('MLC_TMP_ML_MODEL_PROVIDER' , '' ) == 'nvidia' :
95+ env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = os .path .join (
96+ env ['MLC_NVIDIA_MLPERF_SCRATCH_PATH' ],
97+ 'models' ,
98+ 'GPTJ-6B' ,
99+ 'fp8-quantized-ammo' ,
100+ 'GPTJ-FP8-quantized' )
89101 else :
90- return {'return' : 1 , 'error' : 'pax_gptj_checkpoint generation failed' }
91- elif env .get ('MLC_TMP_ML_MODEL_PROVIDER' , '' ) == 'nvidia' :
92- env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = os .path .join (
93- env ['MLC_NVIDIA_MLPERF_SCRATCH_PATH' ],
94- 'models' ,
95- 'GPTJ-6B' ,
96- 'fp8-quantized-ammo' ,
97- 'GPTJ-FP8-quantized' )
98- else :
99- env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = env ['GPTJ_CHECKPOINT_PATH' ]
102+ env ['MLC_ML_MODEL_FILE_WITH_PATH' ] = env ['GPTJ_CHECKPOINT_PATH' ]
100103
101- env ['MLC_ML_MODEL_FILE' ] = os .path .basename (
102- env ['MLC_ML_MODEL_FILE_WITH_PATH' ])
103- env ['MLC_GET_DEPENDENT_CACHED_PATH' ] = env ['MLC_ML_MODEL_FILE_WITH_PATH' ]
104+ env ['MLC_ML_MODEL_FILE' ] = os .path .basename (
105+ env ['MLC_ML_MODEL_FILE_WITH_PATH' ])
106+ env ['MLC_GET_DEPENDENT_CACHED_PATH' ] = env ['MLC_ML_MODEL_FILE_WITH_PATH' ]
104107
105108 return {'return' : 0 }
0 commit comments