Skip to content

Commit 67d84d3

Browse files
add ngc download bundle (#5710)
Signed-off-by: Yiheng Wang <[email protected]> Fixes #5679 and #5320 ### Description This PR adds support for downloading bundles from NGC: https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai In addition, when "version" is not provided, the latest version is now downloaded by default. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang <[email protected]>
1 parent 0abd04e commit 67d84d3

File tree

4 files changed

+164
-49
lines changed

4 files changed

+164
-49
lines changed

.github/workflows/cron-mmar.yml renamed to .github/workflows/cron-ngc-bundle.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
# daily tests for clara mmar models
2-
name: cron-mmar
1+
# daily tests for ngc bundles
2+
name: cron-ngc-bundle
33

44
on:
5-
# schedule:
6-
# - cron: "0 2 * * *" # at 02:00 UTC
5+
schedule:
6+
- cron: "0 2 * * *" # at 02:00 UTC
77
# Allows you to run this workflow manually from the Actions tab
88
workflow_dispatch:
99

1010
concurrency:
1111
# automatically cancel the previously triggered workflows when there's a newer version
12-
group: mmar-tests-${{ github.event.pull_request.number || github.ref }}
12+
group: bundle-tests-${{ github.event.pull_request.number || github.ref }}
1313
cancel-in-progress: true
1414

1515
jobs:
@@ -33,12 +33,12 @@ jobs:
3333
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
3434
- name: Install dependencies
3535
run: |
36-
rm -rf /github/home/.cache/torch/hub/mmars/
36+
rm -rf /github/home/.cache/torch/hub/bundle/
3737
python -m pip install --upgrade pip wheel
3838
python -m pip install -r requirements-dev.txt
39-
- name: Loading MMARs
39+
- name: Loading Bundles
4040
run: |
4141
# clean up temporary files
4242
$(pwd)/runtests.sh --build --clean
4343
# run tests
44-
python -m tests.ngc_mmar_loading
44+
python -m tests.ngc_bundle_download

monai/bundle/scripts.py

Lines changed: 110 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch
2626
from torch.cuda import is_available
2727

28+
from monai.apps.mmars.mmars import _get_all_ngc_models
2829
from monai.apps.utils import _basename, download_url, extractall, get_logger
2930
from monai.bundle.config_item import ConfigComponent
3031
from monai.bundle.config_parser import ConfigParser
@@ -42,6 +43,9 @@
4243

4344
logger = get_logger(module_name=__name__)
4445

46+
# set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC as the default source for bundle download
47+
download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
48+
4549

4650
def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = True, **kwargs) -> Dict:
4751
"""
@@ -130,9 +134,11 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
130134
return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}"
131135

132136

137+
def _get_ngc_bundle_url(model_name: str, version: str):
138+
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name}/versions/{version}/zip"
139+
140+
133141
def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True):
134-
if len(repo.split("/")) != 3:
135-
raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.")
136142
repo_owner, repo_name, tag_name = repo.split("/")
137143
if ".zip" not in filename:
138144
filename += ".zip"
@@ -142,6 +148,45 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres
142148
extractall(filepath=filepath, output_dir=download_path, has_base=True)
143149

144150

151+
def _add_ngc_prefix(name: str, prefix: str = "monai_"):
152+
if name.startswith(prefix):
153+
return name
154+
return f"{prefix}{name}"
155+
156+
157+
def _remove_ngc_prefix(name: str, prefix: str = "monai_"):
158+
if name.startswith(prefix):
159+
return name[len(prefix) :]
160+
return name
161+
162+
163+
def _download_from_ngc(download_path: Path, filename: str, version: str, remove_prefix: Optional[str], progress: bool):
    """
    Download a bundle zip archive from NGC and extract it under ``download_path``.

    Args:
        download_path: directory that receives both the zip archive and the extracted folder.
        filename: bundle name; the ``monai_`` prefix is added when it is missing.
        version: bundle version, used in the request URL and in the archive file name.
        remove_prefix: if not ``None``/empty, this prefix is stripped from the name of the
            extracted folder. The archive itself keeps the prefixed name.
        progress: whether to display a progress bar while downloading.
    """
    # the request URL and the archive name always use the prefixed name
    ngc_name = _add_ngc_prefix(filename)
    url = _get_ngc_bundle_url(model_name=ngc_name, version=version)
    filepath = download_path / f"{ngc_name}_v{version}.zip"
    # strip the prefix the caller asked for, not the helper's default one
    folder_name = _remove_ngc_prefix(ngc_name, prefix=remove_prefix) if remove_prefix else ngc_name
    extract_path = download_path / f"{folder_name}"
    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
    extractall(filepath=filepath, output_dir=extract_path, has_base=True)
173+
174+
175+
def _get_latest_bundle_version(source: str, name: str, repo: str):
176+
if source == "ngc":
177+
name = _add_ngc_prefix(name)
178+
model_dict = _get_all_ngc_models(name)
179+
for v in model_dict.values():
180+
if v["name"] == name:
181+
return v["latest"]
182+
return None
183+
elif source == "github":
184+
repo_owner, repo_name, tag_name = repo.split("/")
185+
return get_bundle_versions(name, repo=os.path.join(repo_owner, repo_name), tag=tag_name)["latest_version"]
186+
else:
187+
raise ValueError(f"To get the latest bundle version, source should be 'github' or 'ngc', got {source}.")
188+
189+
145190
def _process_bundle_dir(bundle_dir: Optional[PathLike] = None):
146191
if bundle_dir is None:
147192
get_dir, has_home = optional_import("torch.hub", name="get_dir")
@@ -156,9 +201,10 @@ def download(
156201
name: Optional[str] = None,
157202
version: Optional[str] = None,
158203
bundle_dir: Optional[PathLike] = None,
159-
source: str = "github",
160-
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
204+
source: str = download_source,
205+
repo: Optional[str] = None,
161206
url: Optional[str] = None,
207+
remove_prefix: Optional[str] = "monai_",
162208
progress: bool = True,
163209
args_file: Optional[str] = None,
164210
):
@@ -175,9 +221,12 @@ def download(
175221
# Execute this module as a CLI entry, and download bundle from the model-zoo repo:
176222
python -m monai.bundle download --name <bundle_name> --version "0.1.0" --bundle_dir "./"
177223
178-
# Execute this module as a CLI entry, and download bundle:
224+
# Execute this module as a CLI entry, and download bundle from specified github repo:
179225
python -m monai.bundle download --name <bundle_name> --source "github" --repo "repo_owner/repo_name/release_tag"
180226
227+
# Execute this module as a CLI entry, and download bundle from ngc with latest version:
228+
python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
229+
181230
# Execute this module as a CLI entry, and download bundle via URL:
182231
python -m monai.bundle download --name <bundle_name> --url <url>
183232
@@ -190,18 +239,27 @@ def download(
190239
191240
Args:
192241
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
193-
for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
242+
for example:
243+
"spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
194244
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
195-
version: version name of the target bundle to download, like: "0.1.0".
245+
"monai_brats_mri_segmentation" in ngc:
246+
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
247+
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
248+
the latest version.
196249
bundle_dir: target directory to store the downloaded data.
197250
Default is `bundle` subfolder under `torch.hub.get_dir()`.
198251
source: storage location name. This argument is used when `url` is `None`.
199-
"github" is currently the only supported value.
200-
repo: repo name. This argument is used when `url` is `None`.
201-
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
252+
By default, the value is read from the environment variable BUNDLE_DOWNLOAD_SRC, and
253+
it should be "ngc" or "github".
254+
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
255+
If used, it should be in the form of "repo_owner/repo_name/release_tag".
202256
url: url to download the data. If not `None`, data will be downloaded directly
203257
and `source` will not be checked.
204258
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
259+
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
260+
have the ``monai_`` prefix, which their model-zoo counterparts do not have. In order to
261+
maintain consistency between these two sources, the prefix needs to be removed.
262+
Therefore, if specified, the prefix is removed from the downloaded folder name.
205263
progress: whether to display a progress bar.
206264
args_file: a JSON or YAML file to provide default values for all the args in this function.
207265
so that the command line inputs can be simplified.
@@ -215,17 +273,20 @@ def download(
215273
source=source,
216274
repo=repo,
217275
url=url,
276+
remove_prefix=remove_prefix,
218277
progress=progress,
219278
)
220279

221280
_log_input_summary(tag="download", args=_args)
222-
source_, repo_, progress_, name_, version_, bundle_dir_, url_ = _pop_args(
223-
_args, "source", "repo", "progress", name=None, version=None, bundle_dir=None, url=None
281+
source_, progress_, remove_prefix_, repo_, name_, version_, bundle_dir_, url_ = _pop_args(
282+
_args, "source", "progress", remove_prefix=None, repo=None, name=None, version=None, bundle_dir=None, url=None
224283
)
225284

226285
bundle_dir_ = _process_bundle_dir(bundle_dir_)
227-
if name_ is not None and version_ is not None:
228-
name_ = "_v".join([name_, version_])
286+
if repo_ is None:
287+
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
288+
if len(repo_.split("/")) != 3:
289+
raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
229290

230291
if url_ is not None:
231292
if name_ is not None:
@@ -234,14 +295,27 @@ def download(
234295
filepath = bundle_dir_ / f"{_basename(url_)}"
235296
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
236297
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
237-
elif source_ == "github":
238-
if name_ is None:
239-
raise ValueError(f"To download from source: Github, `name` must be provided, got {name_}.")
240-
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
241298
else:
242-
raise NotImplementedError(
243-
f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}."
244-
)
299+
if name_ is None:
300+
raise ValueError(f"To download from source: {source_}, `name` must be provided.")
301+
if version_ is None:
302+
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
303+
if source_ == "github":
304+
if version_ is not None:
305+
name_ = "_v".join([name_, version_])
306+
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
307+
elif source_ == "ngc":
308+
_download_from_ngc(
309+
download_path=bundle_dir_,
310+
filename=name_,
311+
version=version_,
312+
remove_prefix=remove_prefix_,
313+
progress=progress_,
314+
)
315+
else:
316+
raise NotImplementedError(
317+
f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}."
318+
)
245319

246320

247321
def load(
@@ -250,8 +324,8 @@ def load(
250324
model_file: Optional[str] = None,
251325
load_ts_module: bool = False,
252326
bundle_dir: Optional[PathLike] = None,
253-
source: str = "github",
254-
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
327+
source: str = download_source,
328+
repo: Optional[str] = None,
255329
progress: bool = True,
256330
device: Optional[str] = None,
257331
key_in_ckpt: Optional[str] = None,
@@ -263,18 +337,25 @@ def load(
263337
Load model weights or TorchScript module of a bundle.
264338
265339
Args:
266-
name: bundle name, for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
340+
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
341+
for example:
342+
"spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
267343
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
268-
version: version name of the target bundle to download, like: "0.1.0".
344+
"monai_brats_mri_segmentation" in ngc:
345+
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
346+
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
347+
the latest version.
269348
model_file: the relative path of the model weights or TorchScript module within bundle.
270349
If `None`, "models/model.pt" or "models/model.ts" will be used.
271350
load_ts_module: a flag to specify if loading the TorchScript module.
272351
bundle_dir: directory the weights/TorchScript module will be loaded from.
273352
Default is `bundle` subfolder under `torch.hub.get_dir()`.
274353
source: storage location name. This argument is used when `model_file` does not exist locally and needs to be
275-
downloaded first. "github" is currently the only supported value.
276-
repo: repo name. This argument is used when `model_file` is not existing locally and need to be
277-
downloaded first. If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
354+
downloaded first.
355+
By default, the value is read from the environment variable BUNDLE_DOWNLOAD_SRC, and
356+
it should be "ngc" or "github".
357+
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
358+
If used, it should be in the form of "repo_owner/repo_name/release_tag".
278359
progress: whether to display a progress bar when downloading.
279360
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
280361
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
@@ -421,7 +502,7 @@ def get_bundle_versions(
421502

422503
bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
423504
if bundle_name not in bundles_info:
424-
raise ValueError(f"bundle: {bundle_name} is not existing.")
505+
raise ValueError(f"bundle: {bundle_name} is not existing in repo: {repo}.")
425506
bundle_info = bundles_info[bundle_name]
426507
all_versions = sorted(bundle_info.keys())
427508

tests/ngc_mmar_loading.py renamed to tests/ngc_bundle_download.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,50 @@
1111

1212
import os
1313
import sys
14+
import tempfile
1415
import unittest
1516

1617
import torch
1718
from parameterized import parameterized
1819

20+
from monai.apps import check_hash
1921
from monai.apps.mmars import MODEL_DESC, load_from_mmar
22+
from monai.bundle import download
2023
from monai.config import print_debug_info
2124
from monai.networks.utils import copy_model_state
25+
from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows
26+
27+
TEST_CASE_NGC_1 = [
28+
"spleen_ct_segmentation",
29+
"0.3.7",
30+
None,
31+
"monai_spleen_ct_segmentation",
32+
"models/model.pt",
33+
"b418a2dc8672ce2fd98dc255036e7a3d",
34+
]
35+
TEST_CASE_NGC_2 = [
36+
"monai_spleen_ct_segmentation",
37+
"0.3.7",
38+
"monai_",
39+
"spleen_ct_segmentation",
40+
"models/model.pt",
41+
"b418a2dc8672ce2fd98dc255036e7a3d",
42+
]
43+
44+
45+
@skip_if_windows
class TestNgcBundleDownload(unittest.TestCase):
    """Download tests for NGC-hosted bundles, with and without prefix removal."""

    @parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2])
    @skip_if_quick
    def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val):
        # download into a temporary directory, then check the expected file and its hash
        with skip_if_downloading_fails(), tempfile.TemporaryDirectory() as tempdir:
            download(
                name=bundle_name,
                source="ngc",
                version=version,
                bundle_dir=tempdir,
                remove_prefix=remove_prefix,
            )
            downloaded_file = os.path.join(tempdir, download_name, file_path)
            self.assertTrue(os.path.exists(downloaded_file))
            self.assertTrue(check_hash(filepath=downloaded_file, val=hash_val))
2258

2359

2460
@unittest.skip("deprecating mmar tests")

0 commit comments

Comments
 (0)