@@ -387,14 +387,10 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) ->
387
387
if config_path :
388
388
defaults = read_file (config_path )
389
389
if self .nested_mode == NestedMode .WITHOUT_ROOT and len (self ._wrappers ) == 1 :
390
- # The file should have the same format as the command-line args, e.g. contain the
391
- # fields of the 'root' dataclass directly (e.g. "foo: 123"), rather a dict with
392
- # "config: foo: 123" where foo is a field of the root dataclass at dest 'config'.
393
- # Therefore, we add the prefix back here.
394
- defaults = {self ._wrappers [0 ].dest : defaults }
395
- # We also assume that the kwargs are passed as foo=123
396
- kwargs = {self ._wrappers [0 ].dest : kwargs }
397
- # Also include the values from **kwargs.
390
+ # The file should have the same format as the command-line args
391
+ wrapper = self ._wrappers [0 ]
392
+ defaults = {wrapper .dest : defaults }
393
+ kwargs = {wrapper .dest : kwargs }
398
394
kwargs = dict_union (defaults , kwargs )
399
395
400
396
# The kwargs that are set in the dataclasses, rather than on the namespace.
@@ -640,7 +636,7 @@ def _resolve_subgroups(
640
636
# config_path=self.config_path,
641
637
# NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues
642
638
# for example if you have —a_or_b and A has a field —a then it will error out if you
643
- # pass —a=1 because 1 isn’ t a choice for the a_or_b argument (because --a matches it
639
+ # pass —a=1 because 1 isn' t a choice for the a_or_b argument (because --a matches it
644
640
# with the abbreviation feature turned on).
645
641
allow_abbrev = False ,
646
642
)
@@ -827,6 +823,8 @@ def _instantiate_dataclasses(
827
823
argparse.Namespace
828
824
The transformed namespace with the instances set at their
829
825
corresponding destinations.
826
+ Also keeps whatever arguments were added in the traditional fashion,
827
+ i.e. with `parser.add_argument(...)`.
830
828
"""
831
829
constructor_arguments = constructor_arguments .copy ()
832
830
# FIXME: There's a bug here happening with the `ALWAYS_MERGE` case: The namespace has the
@@ -1157,5 +1155,45 @@ def _create_dataclass_instance(
1157
1155
else :
1158
1156
logger .debug (f"All fields for { wrapper .dest } were either at their default, or None." )
1159
1157
return None
1158
+
1159
+ # Handle subgroup fields
1160
+ subgroup_fields = {f for f in wrapper .fields if f .is_subgroup }
1161
+ if subgroup_fields :
1162
+ # Create a copy of constructor args to avoid modifying the original
1163
+ filtered_args = constructor_args .copy ()
1164
+
1165
+ # Remove _type_ field if present at top level
1166
+ filtered_args .pop ("_type_" , None )
1167
+
1168
+ # For each subgroup field, check if we have parameters that belong to its default type
1169
+ for field_wrapper in subgroup_fields :
1170
+ default_key = field_wrapper .subgroup_default
1171
+ if default_key is not None and default_key is not dataclasses .MISSING :
1172
+ choices = field_wrapper .subgroup_choices
1173
+ default_factory = choices [default_key ]
1174
+ if callable (default_factory ):
1175
+ default_type = default_factory
1176
+ else :
1177
+ default_type = type (default_factory )
1178
+
1179
+ # Get fields of the default type
1180
+ default_subgroup_fields = {f .name for f in dataclasses .fields (default_type )}
1181
+
1182
+ # Find which fields in the input match fields in the default subgroup
1183
+ matching_fields = {name : filtered_args [name ] for name in list (filtered_args .keys ())
1184
+ if name in default_subgroup_fields }
1185
+
1186
+ if matching_fields :
1187
+ # Create an instance of the default type with the matching fields
1188
+ subgroup_instance = default_type (** matching_fields )
1189
+ filtered_args [field_wrapper .name ] = subgroup_instance
1190
+
1191
+ # Remove handled fields
1192
+ for name in matching_fields :
1193
+ filtered_args .pop (name , None )
1194
+
1195
+ # Use the filtered args to create the instance
1196
+ constructor_args = filtered_args
1197
+
1160
1198
logger .debug (f"Calling constructor: { constructor } (**{ constructor_args } )" )
1161
1199
return constructor (** constructor_args )
0 commit comments