Skip to content

Commit 865972f

Browse files
authored
Fix unsuitable default value to bundle_root in ckpt_export (#7124)
Fixes #7123 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <[email protected]>
1 parent e36982b commit 865972f

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

Diff for: monai/bundle/scripts.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -1273,7 +1273,6 @@ def ckpt_export(
12731273
config_file_,
12741274
filepath_,
12751275
ckpt_file_,
1276-
bundle_root_,
12771276
net_id_,
12781277
meta_file_,
12791278
key_in_ckpt_,
@@ -1285,26 +1284,30 @@ def ckpt_export(
12851284
"config_file",
12861285
filepath=None,
12871286
ckpt_file=None,
1288-
bundle_root=os.getcwd(),
12891287
net_id=None,
12901288
meta_file=None,
12911289
key_in_ckpt="",
12921290
use_trace=False,
12931291
input_shape=None,
12941292
converter_kwargs={},
12951293
)
1294+
bundle_root = _args.get("bundle_root", os.getcwd())
12961295

12971296
parser = ConfigParser()
1298-
12991297
parser.read_config(f=config_file_)
1300-
meta_file_ = os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ is None else meta_file_
1301-
filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ is None else filepath_
1302-
ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ is None else ckpt_file_
1303-
if not os.path.exists(ckpt_file_):
1304-
raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')
1298+
meta_file_ = os.path.join(bundle_root, "configs", "metadata.json") if meta_file_ is None else meta_file_
13051299
if os.path.exists(meta_file_):
13061300
parser.read_meta(f=meta_file_)
13071301

1302+
# the rest key-values in the _args are to override config content
1303+
for k, v in _args.items():
1304+
parser[k] = v
1305+
1306+
filepath_ = os.path.join(bundle_root, "models", "model.ts") if filepath_ is None else filepath_
1307+
ckpt_file_ = os.path.join(bundle_root, "models", "model.pt") if ckpt_file_ is None else ckpt_file_
1308+
if not os.path.exists(ckpt_file_):
1309+
raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')
1310+
13081311
net_id_ = "network_def" if net_id_ is None else net_id_
13091312
try:
13101313
parser.get_parsed_content(net_id_)
@@ -1313,10 +1316,6 @@ def ckpt_export(
13131316
f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".'
13141317
) from e
13151318

1316-
# the rest key-values in the _args are to override config content
1317-
for k, v in _args.items():
1318-
parser[k] = v
1319-
13201319
# When export through torch.jit.trace without providing input_shape, will try to parse one from the parser.
13211320
if (not input_shape_) and use_trace:
13221321
input_shape_ = _get_fake_input_shape(parser=parser)

0 commit comments

Comments
 (0)