25
25
import torch
26
26
from torch .cuda import is_available
27
27
28
+ from monai .apps .mmars .mmars import _get_all_ngc_models
28
29
from monai .apps .utils import _basename , download_url , extractall , get_logger
29
30
from monai .bundle .config_item import ConfigComponent
30
31
from monai .bundle .config_parser import ConfigParser
42
43
43
44
logger = get_logger (module_name = __name__ )
44
45
46
+ # set the environment variable BUNDLE_DOWNLOAD_SRC="ngc" to use NGC as the default source for bundle download
47
+ download_source = os .environ .get ("BUNDLE_DOWNLOAD_SRC" , "github" )
48
+
45
49
46
50
def _update_args (args : Optional [Union [str , Dict ]] = None , ignore_none : bool = True , ** kwargs ) -> Dict :
47
51
"""
@@ -130,9 +134,11 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
130
134
return f"https://github.com/{ repo_owner } /{ repo_name } /releases/download/{ tag_name } /{ filename } "
131
135
132
136
137
+ def _get_ngc_bundle_url (model_name : str , version : str ):
138
+ return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{ model_name } /versions/{ version } /zip"
139
+
140
+
133
141
def _download_from_github (repo : str , download_path : Path , filename : str , progress : bool = True ):
134
- if len (repo .split ("/" )) != 3 :
135
- raise ValueError ("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`." )
136
142
repo_owner , repo_name , tag_name = repo .split ("/" )
137
143
if ".zip" not in filename :
138
144
filename += ".zip"
@@ -142,6 +148,45 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres
142
148
extractall (filepath = filepath , output_dir = download_path , has_base = True )
143
149
144
150
151
+ def _add_ngc_prefix (name : str , prefix : str = "monai_" ):
152
+ if name .startswith (prefix ):
153
+ return name
154
+ return f"{ prefix } { name } "
155
+
156
+
157
+ def _remove_ngc_prefix (name : str , prefix : str = "monai_" ):
158
+ if name .startswith (prefix ):
159
+ return name [len (prefix ) :]
160
+ return name
161
+
162
+
163
def _download_from_ngc(download_path: Path, filename: str, version: str, remove_prefix: Optional[str], progress: bool):
    """
    Download a bundle zip archive from NGC and extract it under ``download_path``.

    Args:
        download_path: directory in which the zip archive is saved and the bundle is extracted.
        filename: bundle name; the NGC prefix is added via ``_add_ngc_prefix`` if it is missing.
        version: version of the bundle, used in both the URL and the archive file name.
        remove_prefix: if truthy, this prefix is stripped from the name of the extraction
            folder (the downloaded archive keeps the prefixed name).
        progress: whether to display a progress bar while downloading.
    """
    # ensure prefix is contained, the NGC URL is built from the prefixed name
    filename = _add_ngc_prefix(filename)
    url = _get_ngc_bundle_url(model_name=filename, version=version)
    filepath = download_path / f"{filename}_v{version}.zip"
    if remove_prefix:
        # honour the caller-supplied prefix instead of always falling back to the helper's default
        filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
    extract_path = download_path / f"{filename}"
    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
    extractall(filepath=filepath, output_dir=extract_path, has_base=True)
173
+
174
+
175
+ def _get_latest_bundle_version (source : str , name : str , repo : str ):
176
+ if source == "ngc" :
177
+ name = _add_ngc_prefix (name )
178
+ model_dict = _get_all_ngc_models (name )
179
+ for v in model_dict .values ():
180
+ if v ["name" ] == name :
181
+ return v ["latest" ]
182
+ return None
183
+ elif source == "github" :
184
+ repo_owner , repo_name , tag_name = repo .split ("/" )
185
+ return get_bundle_versions (name , repo = os .path .join (repo_owner , repo_name ), tag = tag_name )["latest_version" ]
186
+ else :
187
+ raise ValueError (f"To get the latest bundle version, source should be 'github' or 'ngc', got { source } ." )
188
+
189
+
145
190
def _process_bundle_dir (bundle_dir : Optional [PathLike ] = None ):
146
191
if bundle_dir is None :
147
192
get_dir , has_home = optional_import ("torch.hub" , name = "get_dir" )
@@ -156,9 +201,10 @@ def download(
156
201
name : Optional [str ] = None ,
157
202
version : Optional [str ] = None ,
158
203
bundle_dir : Optional [PathLike ] = None ,
159
- source : str = "github" ,
160
- repo : str = "Project-MONAI/model-zoo/hosting_storage_v1" ,
204
+ source : str = download_source ,
205
+ repo : Optional [ str ] = None ,
161
206
url : Optional [str ] = None ,
207
+ remove_prefix : Optional [str ] = "monai_" ,
162
208
progress : bool = True ,
163
209
args_file : Optional [str ] = None ,
164
210
):
@@ -175,9 +221,12 @@ def download(
175
221
# Execute this module as a CLI entry, and download bundle from the model-zoo repo:
176
222
python -m monai.bundle download --name <bundle_name> --version "0.1.0" --bundle_dir "./"
177
223
178
- # Execute this module as a CLI entry, and download bundle:
224
+ # Execute this module as a CLI entry, and download bundle from specified github repo :
179
225
python -m monai.bundle download --name <bundle_name> --source "github" --repo "repo_owner/repo_name/release_tag"
180
226
227
+ # Execute this module as a CLI entry, and download bundle from ngc with latest version:
228
+ python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
229
+
181
230
# Execute this module as a CLI entry, and download bundle via URL:
182
231
python -m monai.bundle download --name <bundle_name> --url <url>
183
232
@@ -190,18 +239,27 @@ def download(
190
239
191
240
Args:
192
241
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
193
- for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
242
+ for example:
243
+ "spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
194
244
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
195
- version: version name of the target bundle to download, like: "0.1.0".
245
+ "monai_brats_mri_segmentation" in ngc:
246
+ https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
247
+ version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
248
+ the latest version.
196
249
bundle_dir: target directory to store the downloaded data.
197
250
Default is `bundle` subfolder under `torch.hub.get_dir()`.
198
251
source: storage location name. This argument is used when `url` is `None`.
199
- "github" is currently the only supported value.
200
- repo: repo name. This argument is used when `url` is `None`.
201
- If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
252
+ By default, the value is read from the environment variable BUNDLE_DOWNLOAD_SRC, and
253
+ it should be "ngc" or "github".
254
+ repo: repo name. This argument is used when `url` is `None` and `source` is "github".
255
+ If used, it should be in the form of "repo_owner/repo_name/release_tag".
202
256
url: url to download the data. If not `None`, data will be downloaded directly
203
257
and `source` will not be checked.
204
258
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
259
+ remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
260
+ have the ``monai_`` prefix, which does not exist in their model zoo counterparts. In order to
261
+ maintain consistency between these two sources, removing the prefix is necessary.
262
+ Therefore, if specified, the prefix will be removed from the downloaded folder name.
205
263
progress: whether to display a progress bar.
206
264
args_file: a JSON or YAML file to provide default values for all the args in this function.
207
265
so that the command line inputs can be simplified.
@@ -215,17 +273,20 @@ def download(
215
273
source = source ,
216
274
repo = repo ,
217
275
url = url ,
276
+ remove_prefix = remove_prefix ,
218
277
progress = progress ,
219
278
)
220
279
221
280
_log_input_summary (tag = "download" , args = _args )
222
- source_ , repo_ , progress_ , name_ , version_ , bundle_dir_ , url_ = _pop_args (
223
- _args , "source" , "repo " , "progress" , name = None , version = None , bundle_dir = None , url = None
281
+ source_ , progress_ , remove_prefix_ , repo_ , name_ , version_ , bundle_dir_ , url_ = _pop_args (
282
+ _args , "source" , "progress " , remove_prefix = None , repo = None , name = None , version = None , bundle_dir = None , url = None
224
283
)
225
284
226
285
bundle_dir_ = _process_bundle_dir (bundle_dir_ )
227
- if name_ is not None and version_ is not None :
228
- name_ = "_v" .join ([name_ , version_ ])
286
+ if repo_ is None :
287
+ repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
288
+ if len (repo_ .split ("/" )) != 3 :
289
+ raise ValueError ("repo should be in the form of `repo_owner/repo_name/release_tag`." )
229
290
230
291
if url_ is not None :
231
292
if name_ is not None :
@@ -234,14 +295,27 @@ def download(
234
295
filepath = bundle_dir_ / f"{ _basename (url_ )} "
235
296
download_url (url = url_ , filepath = filepath , hash_val = None , progress = progress_ )
236
297
extractall (filepath = filepath , output_dir = bundle_dir_ , has_base = True )
237
- elif source_ == "github" :
238
- if name_ is None :
239
- raise ValueError (f"To download from source: Github, `name` must be provided, got { name_ } ." )
240
- _download_from_github (repo = repo_ , download_path = bundle_dir_ , filename = name_ , progress = progress_ )
241
298
else :
242
- raise NotImplementedError (
243
- f"Currently only download from provided URL in `url` or Github is implemented, got source: { source_ } ."
244
- )
299
+ if name_ is None :
300
+ raise ValueError (f"To download from source: { source_ } , `name` must be provided." )
301
+ if version_ is None :
302
+ version_ = _get_latest_bundle_version (source = source_ , name = name_ , repo = repo_ )
303
+ if source_ == "github" :
304
+ if version_ is not None :
305
+ name_ = "_v" .join ([name_ , version_ ])
306
+ _download_from_github (repo = repo_ , download_path = bundle_dir_ , filename = name_ , progress = progress_ )
307
+ elif source_ == "ngc" :
308
+ _download_from_ngc (
309
+ download_path = bundle_dir_ ,
310
+ filename = name_ ,
311
+ version = version_ ,
312
+ remove_prefix = remove_prefix_ ,
313
+ progress = progress_ ,
314
+ )
315
+ else :
316
+ raise NotImplementedError (
317
+ f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: { source_ } ."
318
+ )
245
319
246
320
247
321
def load (
@@ -250,8 +324,8 @@ def load(
250
324
model_file : Optional [str ] = None ,
251
325
load_ts_module : bool = False ,
252
326
bundle_dir : Optional [PathLike ] = None ,
253
- source : str = "github" ,
254
- repo : str = "Project-MONAI/model-zoo/hosting_storage_v1" ,
327
+ source : str = download_source ,
328
+ repo : Optional [ str ] = None ,
255
329
progress : bool = True ,
256
330
device : Optional [str ] = None ,
257
331
key_in_ckpt : Optional [str ] = None ,
@@ -263,18 +337,25 @@ def load(
263
337
Load model weights or TorchScript module of a bundle.
264
338
265
339
Args:
266
- name: bundle name, for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
340
+ name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
341
+ for example:
342
+ "spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
267
343
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
268
- version: version name of the target bundle to download, like: "0.1.0".
344
+ "monai_brats_mri_segmentation" in ngc:
345
+ https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
346
+ version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
347
+ the latest version.
269
348
model_file: the relative path of the model weights or TorchScript module within bundle.
270
349
If `None`, "models/model.pt" or "models/model.ts" will be used.
271
350
load_ts_module: a flag to specify if loading the TorchScript module.
272
351
bundle_dir: directory the weights/TorchScript module will be loaded from.
273
352
Default is `bundle` subfolder under `torch.hub.get_dir()`.
274
353
source: storage location name. This argument is used when `model_file` is not existing locally and need to be
275
- downloaded first. "github" is currently the only supported value.
276
- repo: repo name. This argument is used when `model_file` is not existing locally and need to be
277
- downloaded first. If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
354
+ downloaded first.
355
+ By default, the value is read from the environment variable BUNDLE_DOWNLOAD_SRC, and
356
+ it should be "ngc" or "github".
357
+ repo: repo name. This argument is used when `url` is `None` and `source` is "github".
358
+ If used, it should be in the form of "repo_owner/repo_name/release_tag".
278
359
progress: whether to display a progress bar when downloading.
279
360
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
280
361
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
@@ -421,7 +502,7 @@ def get_bundle_versions(
421
502
422
503
bundles_info = _get_all_bundles_info (repo = repo , tag = tag , auth_token = auth_token )
423
504
if bundle_name not in bundles_info :
424
- raise ValueError (f"bundle: { bundle_name } is not existing." )
505
+ raise ValueError (f"bundle: { bundle_name } is not existing in repo: { repo } ." )
425
506
bundle_info = bundles_info [bundle_name ]
426
507
all_versions = sorted (bundle_info .keys ())
427
508
0 commit comments