Skip to content

Commit 8bc886c

Browse files
authored
Add support for the r2-downloader for GPT-J (#680)
* Initial commit supporting the r2 downloader in GPT-J. * Fix the download-tool group issue.
1 parent 31b48fd commit 8bc886c

File tree

3 files changed

+66
-36
lines changed

3 files changed

+66
-36
lines changed

script/get-dataset-cnndm/meta.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ deps:
3535
skip_if_env:
3636
MLC_TMP_ML_MODEL:
3737
- llama3_1-8b
38+
- tags: get,generic-python-lib,_package.transformers
39+
skip_if_env:
40+
MLC_TMP_ML_MODEL:
41+
- llama3_1-8b
3842
- tags: get,generic-python-lib,_numpy
3943
skip_if_env:
4044
MLC_TMP_ML_MODEL:

script/get-ml-model-gptj/customize.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -67,39 +67,42 @@ def postprocess(i):
6767

6868
env = i['env']
6969

70-
if os.path.exists(os.path.join(
71-
env['GPTJ_CHECKPOINT_PATH'], "checkpoint-final")):
72-
env['GPTJ_CHECKPOINT_PATH'] = os.path.join(
73-
env['GPTJ_CHECKPOINT_PATH'], "checkpoint-final")
74-
75-
is_saxml = env.get('MLC_TMP_MODEL_SAXML', '')
76-
if is_saxml == "fp32":
77-
if os.path.exists("pax_gptj_checkpoint"):
78-
env['GPTJ_SAXML_CHECKPOINT_PATH'] = os.path.join(
79-
os.getcwd(), "pax_gptj_checkpoint")
80-
env['MLC_ML_MODEL_FILE_WITH_PATH'] = env['GPTJ_SAXML_CHECKPOINT_PATH']
81-
else:
82-
return {'return': 1, 'error': 'pax_gptj_checkpoint generation failed'}
70+
if not env.get('MLC_DOWNLOAD_MODE', '') == "dry":
71+
if os.path.exists(os.path.join(
72+
env['GPTJ_CHECKPOINT_PATH'], "checkpoint-final")):
73+
env['GPTJ_CHECKPOINT_PATH'] = os.path.join(
74+
env['GPTJ_CHECKPOINT_PATH'], "checkpoint-final")
75+
76+
is_saxml = env.get('MLC_TMP_MODEL_SAXML', '')
77+
if is_saxml == "fp32":
78+
if os.path.exists("pax_gptj_checkpoint"):
79+
env['GPTJ_SAXML_CHECKPOINT_PATH'] = os.path.join(
80+
os.getcwd(), "pax_gptj_checkpoint")
81+
env['MLC_ML_MODEL_FILE_WITH_PATH'] = env['GPTJ_SAXML_CHECKPOINT_PATH']
82+
else:
83+
return {'return': 1,
84+
'error': 'pax_gptj_checkpoint generation failed'}
8385

84-
elif is_saxml == "int8":
85-
if os.path.exists("int8_ckpt"):
86-
env['GPTJ_SAXML_INT8_CHECKPOINT_PATH'] = os.path.join(
87-
os.getcwd(), "int8_ckpt")
88-
env['MLC_ML_MODEL_FILE_WITH_PATH'] = env['GPTJ_SAXML_INT8_CHECKPOINT_PATH']
86+
elif is_saxml == "int8":
87+
if os.path.exists("int8_ckpt"):
88+
env['GPTJ_SAXML_INT8_CHECKPOINT_PATH'] = os.path.join(
89+
os.getcwd(), "int8_ckpt")
90+
env['MLC_ML_MODEL_FILE_WITH_PATH'] = env['GPTJ_SAXML_INT8_CHECKPOINT_PATH']
91+
else:
92+
return {'return': 1,
93+
'error': 'pax_gptj_checkpoint generation failed'}
94+
elif env.get('MLC_TMP_ML_MODEL_PROVIDER', '') == 'nvidia':
95+
env['MLC_ML_MODEL_FILE_WITH_PATH'] = os.path.join(
96+
env['MLC_NVIDIA_MLPERF_SCRATCH_PATH'],
97+
'models',
98+
'GPTJ-6B',
99+
'fp8-quantized-ammo',
100+
'GPTJ-FP8-quantized')
89101
else:
90-
return {'return': 1, 'error': 'pax_gptj_checkpoint generation failed'}
91-
elif env.get('MLC_TMP_ML_MODEL_PROVIDER', '') == 'nvidia':
92-
env['MLC_ML_MODEL_FILE_WITH_PATH'] = os.path.join(
93-
env['MLC_NVIDIA_MLPERF_SCRATCH_PATH'],
94-
'models',
95-
'GPTJ-6B',
96-
'fp8-quantized-ammo',
97-
'GPTJ-FP8-quantized')
98-
else:
99-
env['MLC_ML_MODEL_FILE_WITH_PATH'] = env['GPTJ_CHECKPOINT_PATH']
102+
env['MLC_ML_MODEL_FILE_WITH_PATH'] = env['GPTJ_CHECKPOINT_PATH']
100103

101-
env['MLC_ML_MODEL_FILE'] = os.path.basename(
102-
env['MLC_ML_MODEL_FILE_WITH_PATH'])
103-
env['MLC_GET_DEPENDENT_CACHED_PATH'] = env['MLC_ML_MODEL_FILE_WITH_PATH']
104+
env['MLC_ML_MODEL_FILE'] = os.path.basename(
105+
env['MLC_ML_MODEL_FILE_WITH_PATH'])
106+
env['MLC_GET_DEPENDENT_CACHED_PATH'] = env['MLC_ML_MODEL_FILE_WITH_PATH']
104107

105108
return {'return': 0}

script/get-ml-model-gptj/meta.yaml

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ prehook_deps:
3737
tests:
3838
run_inputs:
3939
- variations_list:
40-
- fp32
40+
- fp32,pytorch,r2-downloader,dry-run
41+
- fp32,pytorch,rclone,dry-run
4142

4243

4344
print_env_at_the_end:
@@ -102,10 +103,15 @@ variations:
102103
env:
103104
MLC_DOWNLOAD_CHECKSUM_NOT_USED: e677e28aaf03da84584bb3073b7ee315
104105
MLC_PACKAGE_URL: https://cloud.mlcommons.org/index.php/s/QAZ2oM94MkFtbQx/download
105-
MLC_RCLONE_CONFIG_NAME: mlc-inference
106-
MLC_RCLONE_URL: mlc-inference:mlcommons-inference-wg-public/gpt-j
107106
MLC_UNZIP: 'yes'
108107
required_disk_space: 22700
108+
pytorch,fp32,rclone:
109+
env:
110+
MLC_PACKAGE_URL: mlc-inference:mlcommons-inference-wg-public/gpt-j
111+
MLC_RCLONE_CONFIG_NAME: mlc-inference
112+
pytorch,fp32,r2-downloader:
113+
env:
114+
MLC_DOWNLOAD_URL: https://inference.mlcommons-storage.org/metadata/gpt-j-model-checkpoint.uri
109115
pytorch,fp32,wget:
110116
add_deps_recursive:
111117
dae:
@@ -183,11 +189,28 @@ variations:
183189
add_deps_recursive:
184190
dae:
185191
tags: _rclone
186-
default: true
187192
env:
188193
MLC_DOWNLOAD_FILENAME: checkpoint
189-
MLC_DOWNLOAD_URL: <<<MLC_RCLONE_URL>>>
194+
MLC_DOWNLOAD_URL: <<<MLC_PACKAGE_URL>>>
190195
group: download-tool
196+
r2-downloader:
197+
group: download-tool
198+
default: true
199+
add_deps_recursive:
200+
dae:
201+
tags: _r2-downloader
202+
env:
203+
MLC_DOWNLOAD_FILENAME: checkpoint
204+
dry-run:
205+
group: run-mode
206+
env:
207+
MLC_DOWNLOAD_MODE: dry
208+
dry-run,r2-downloader:
209+
env:
210+
MLC_DOWNLOAD_EXTRA_OPTIONS: -x
211+
dry-run,rclone:
212+
env:
213+
MLC_DOWNLOAD_EXTRA_OPTIONS: --dry-run
191214
saxml:
192215
group: framework
193216
saxml,fp32:

0 commit comments

Comments
 (0)