Skip to content

Commit b159ce7

Browse files
5770 add remove prefix to monai.bundle.load (#5771)
Signed-off-by: Yiheng Wang <[email protected]> Fixes #5770 . ### Description This PR enables the `remove_prefix` arg in `monai.bundle.load` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang <[email protected]>
1 parent 8037fcd commit b159ce7

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

monai/bundle/scripts.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ def load(
326326
bundle_dir: Optional[PathLike] = None,
327327
source: str = download_source,
328328
repo: Optional[str] = None,
329+
remove_prefix: Optional[str] = "monai_",
329330
progress: bool = True,
330331
device: Optional[str] = None,
331332
key_in_ckpt: Optional[str] = None,
@@ -356,6 +357,10 @@ def load(
356357
it should be "ngc" or "github".
357358
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
358359
If used, it should be in the form of "repo_owner/repo_name/release_tag".
360+
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
361+
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
362+
maintain the consistency between these two sources, remove prefix is necessary.
363+
Therefore, if specified, downloaded folder name will remove the prefix.
359364
progress: whether to display a progress bar when downloading.
360365
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
361366
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
@@ -379,9 +384,21 @@ def load(
379384

380385
if model_file is None:
381386
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
387+
if source == "ngc":
388+
name = _add_ngc_prefix(name)
389+
if remove_prefix:
390+
name = _remove_ngc_prefix(name, prefix=remove_prefix)
382391
full_path = os.path.join(bundle_dir_, name, model_file)
383392
if not os.path.exists(full_path):
384-
download(name=name, version=version, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)
393+
download(
394+
name=name,
395+
version=version,
396+
bundle_dir=bundle_dir_,
397+
source=source,
398+
repo=repo,
399+
remove_prefix=remove_prefix,
400+
progress=progress,
401+
)
385402

386403
if device is None:
387404
device = "cuda:0" if is_available() else "cpu"

tests/ngc_bundle_download.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
from monai.apps import check_hash
2121
from monai.apps.mmars import MODEL_DESC, load_from_mmar
22-
from monai.bundle import download
22+
from monai.bundle import download, load
2323
from monai.config import print_debug_info
2424
from monai.networks.utils import copy_model_state
25-
from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows
25+
from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, skip_if_windows
2626

2727
TEST_CASE_NGC_1 = [
2828
"spleen_ct_segmentation",
@@ -41,6 +41,30 @@
4141
"b418a2dc8672ce2fd98dc255036e7a3d",
4242
]
4343

44+
TESTCASE_WEIGHTS = {
45+
"key": "model.0.conv.unit0.adn.N.bias",
46+
"value": torch.tensor(
47+
[
48+
-0.0705,
49+
-0.0937,
50+
-0.0422,
51+
-0.2068,
52+
0.1023,
53+
-0.2007,
54+
-0.0883,
55+
0.0018,
56+
-0.1719,
57+
0.0116,
58+
0.0285,
59+
-0.0044,
60+
0.1223,
61+
-0.1287,
62+
-0.1858,
63+
0.0460,
64+
]
65+
),
66+
}
67+
4468

4569
@skip_if_windows
4670
class TestNgcBundleDownload(unittest.TestCase):
@@ -56,6 +80,13 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
5680
self.assertTrue(os.path.exists(full_file_path))
5781
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))
5882

83+
weights = load(
84+
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
85+
)
86+
assert_allclose(
87+
weights[TESTCASE_WEIGHTS["key"]], TESTCASE_WEIGHTS["value"], atol=1e-4, rtol=1e-4, type_test=False
88+
)
89+
5990

6091
@unittest.skip("deprecating mmar tests")
6192
class TestAllDownloadingMMAR(unittest.TestCase):

0 commit comments

Comments
 (0)