Skip to content

Commit da1fed1

Browse files
wr0124beniz
authored andcommitted
fix: warn and exit on invalid command-line arguments
1 parent e96c352 commit da1fed1

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

options/base_options.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def gather_options(self, args=None, json_args=None):
5858

5959
# get the basic options
6060
opt = self._parse_args(parser, None, args, flat_json)
61-
6261
# get specific options
6362
opt, parser = self._gather_specific_options(opt, parser, args, flat_json)
6463

@@ -94,7 +93,10 @@ def _parse_args(self, parser, opt, args, flat_json, only_known=True):
9493
) # it's not an error anymore because server launching is done with all of the options even those from other models, raising an error will lead to a server crash
9594
else:
9695
# do not ignore unknown options here, they are actual errors in the command line
97-
opt = parser.parse_args(args)
96+
if only_known:
97+
opt, _ = parser.parse_known_args(args)
98+
else:
99+
opt = parser.parse_args(args)
98100
return opt
99101

100102
def _json_parse_known_args(self, parser, opt, json_args):

util/parser.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import json
23
import os
34
import torch
@@ -21,13 +22,13 @@ def get_opt(main_opt, remaining_args):
2122

2223
with open(main_opt.config_json, "r") as jsonf:
2324
train_json = flatten_json(json.load(jsonf))
24-
25-
# Save the config file when --save_config is passed
26-
is_config_saved = False
27-
if "save_config" in override_options_names:
28-
is_config_saved = True
29-
override_options_names.remove("save_config")
30-
remaining_args.remove("--save_config")
25+
# # Save the config file when --save_config is passed
26+
# is_config_saved = False
27+
# if "save_config" in override_options_names:
28+
# is_config_saved = True
29+
# override_options_names.remove("save_config")
30+
# remaining_args.remove("--save_config")
31+
#
3132

3233
if not "--dataroot" in remaining_args:
3334
remaining_args += ["--dataroot", "unused"]
@@ -37,6 +38,24 @@ def get_opt(main_opt, remaining_args):
3738
TrainOptions().parse_to_json(remaining_args)
3839
)
3940

41+
# Save the config file when --save_config is passed
42+
is_config_saved = False
43+
if "save_config" in override_options_names:
44+
is_config_saved = True
45+
override_options_names.remove("save_config")
46+
47+
train_keys = set(train_json.keys())
48+
override_keys = set(override_options_json.keys())
49+
override_names = set(override_options_names)
50+
override_not_in_train = override_names - train_keys
51+
override_not_in_override_json = override_names - override_keys
52+
if override_not_in_override_json:
53+
unknown_list = ", ".join(sorted(override_not_in_override_json))
54+
print(
55+
f"\033[93mWARNING: The following command-line options are not recognized: {unknown_list}\033[0m"
56+
)
57+
sys.exit(1)
58+
4059
for name in override_options_names:
4160
train_json[name] = override_options_json[name]
4261

@@ -46,4 +65,14 @@ def get_opt(main_opt, remaining_args):
4665
else:
4766
opt = TrainOptions().parse() # get training options
4867

68+
parsed_opt_keys = set(vars(opt).keys())
69+
commandline_arg_names = set(get_override_options_names(remaining_args))
70+
always_allowed = {"save_config"}
71+
invalid_args = commandline_arg_names - parsed_opt_keys - always_allowed
72+
if invalid_args:
73+
print(
74+
f"\033[93mWARNING: The following command-line options are not recognized by the parser: {sorted(list(invalid_args))}\033[0m"
75+
)
76+
sys.exit(1) # Stop the script if there are any invalid args
77+
4978
return opt

0 commit comments

Comments
 (0)