Skip to content

Commit d00cb03

Browse files
committed
working sol
1 parent 45c309b commit d00cb03

File tree

5 files changed

+111
-28
lines changed

5 files changed

+111
-28
lines changed

Diff for: simple_parsing/parsing.py

+47-9
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,10 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) ->
387387
if config_path:
388388
defaults = read_file(config_path)
389389
if self.nested_mode == NestedMode.WITHOUT_ROOT and len(self._wrappers) == 1:
390-
# The file should have the same format as the command-line args, e.g. contain the
391-
# fields of the 'root' dataclass directly (e.g. "foo: 123"), rather a dict with
392-
# "config: foo: 123" where foo is a field of the root dataclass at dest 'config'.
393-
# Therefore, we add the prefix back here.
394-
defaults = {self._wrappers[0].dest: defaults}
395-
# We also assume that the kwargs are passed as foo=123
396-
kwargs = {self._wrappers[0].dest: kwargs}
397-
# Also include the values from **kwargs.
390+
# The file should have the same format as the command-line args
391+
wrapper = self._wrappers[0]
392+
defaults = {wrapper.dest: defaults}
393+
kwargs = {wrapper.dest: kwargs}
398394
kwargs = dict_union(defaults, kwargs)
399395

400396
# The kwargs that are set in the dataclasses, rather than on the namespace.
@@ -640,7 +636,7 @@ def _resolve_subgroups(
640636
# config_path=self.config_path,
641637
# NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues
642638
# for example if you have —a_or_b and A has a field —a then it will error out if you
643-
# pass —a=1 because 1 isnt a choice for the a_or_b argument (because --a matches it
639+
# pass —a=1 because 1 isn't a choice for the a_or_b argument (because --a matches it
644640
# with the abbreviation feature turned on).
645641
allow_abbrev=False,
646642
)
@@ -827,6 +823,8 @@ def _instantiate_dataclasses(
827823
argparse.Namespace
828824
The transformed namespace with the instances set at their
829825
corresponding destinations.
826+
Also keeps whatever arguments were added in the traditional fashion,
827+
i.e. with `parser.add_argument(...)`.
830828
"""
831829
constructor_arguments = constructor_arguments.copy()
832830
# FIXME: There's a bug here happening with the `ALWAYS_MERGE` case: The namespace has the
@@ -1157,5 +1155,45 @@ def _create_dataclass_instance(
11571155
else:
11581156
logger.debug(f"All fields for {wrapper.dest} were either at their default, or None.")
11591157
return None
1158+
1159+
# Handle subgroup fields
1160+
subgroup_fields = {f for f in wrapper.fields if f.is_subgroup}
1161+
if subgroup_fields:
1162+
# Create a copy of constructor args to avoid modifying the original
1163+
filtered_args = constructor_args.copy()
1164+
1165+
# Remove _type_ field if present at top level
1166+
filtered_args.pop("_type_", None)
1167+
1168+
# For each subgroup field, check if we have parameters that belong to its default type
1169+
for field_wrapper in subgroup_fields:
1170+
default_key = field_wrapper.subgroup_default
1171+
if default_key is not None and default_key is not dataclasses.MISSING:
1172+
choices = field_wrapper.subgroup_choices
1173+
default_factory = choices[default_key]
1174+
if callable(default_factory):
1175+
default_type = default_factory
1176+
else:
1177+
default_type = type(default_factory)
1178+
1179+
# Get fields of the default type
1180+
default_subgroup_fields = {f.name for f in dataclasses.fields(default_type)}
1181+
1182+
# Find which fields in the input match fields in the default subgroup
1183+
matching_fields = {name: filtered_args[name] for name in list(filtered_args.keys())
1184+
if name in default_subgroup_fields}
1185+
1186+
if matching_fields:
1187+
# Create an instance of the default type with the matching fields
1188+
subgroup_instance = default_type(**matching_fields)
1189+
filtered_args[field_wrapper.name] = subgroup_instance
1190+
1191+
# Remove handled fields
1192+
for name in matching_fields:
1193+
filtered_args.pop(name, None)
1194+
1195+
# Use the filtered args to create the instance
1196+
constructor_args = filtered_args
1197+
11601198
logger.debug(f"Calling constructor: {constructor}(**{constructor_args})")
11611199
return constructor(**constructor_args)

Diff for: simple_parsing/wrappers/dataclass_wrapper.py

+55-13
Original file line numberDiff line numberDiff line change
@@ -294,24 +294,66 @@ def set_default(self, value: DataclassT | dict | None):
294294
self._default = value
295295
if field_default_values is None:
296296
return
297-
unknown_names = set(field_default_values)
297+
298+
# First try to handle any subgroup fields
299+
subgroup_fields = {f for f in self.fields if f.is_subgroup}
300+
remaining_fields = field_default_values.copy() # Work with a copy to track what's been handled
301+
302+
for field_wrapper in subgroup_fields:
303+
# Get the default subgroup type from the choices
304+
default_key = field_wrapper.subgroup_default
305+
if default_key is not None and default_key is not dataclasses.MISSING:
306+
choices = field_wrapper.subgroup_choices
307+
default_factory = choices[default_key]
308+
if callable(default_factory):
309+
default_type = default_factory
310+
else:
311+
default_type = type(default_factory)
312+
313+
# Get fields of the default type
314+
default_subgroup_fields = {f.name for f in dataclasses.fields(default_type)}
315+
316+
# Find which fields in the input match fields in the default subgroup
317+
matching_fields = {name: remaining_fields[name] for name in list(remaining_fields.keys())
318+
if name in default_subgroup_fields}
319+
320+
if matching_fields:
321+
# Create the nested structure for the subgroup
322+
subgroup_dict = {
323+
field_wrapper.name: {
324+
"_type_": default_key,
325+
**matching_fields
326+
}
327+
}
328+
# Set this as the default for this field
329+
field_wrapper.set_default(subgroup_dict[field_wrapper.name])
330+
331+
# Remove handled fields
332+
for name in matching_fields:
333+
remaining_fields.pop(name, None)
334+
335+
# Now handle any remaining regular fields
298336
for field_wrapper in self.fields:
299-
if field_wrapper.name not in field_default_values:
337+
if field_wrapper.name not in remaining_fields:
300338
continue
301-
# Manually set the default value for this argument.
302-
field_default_value = field_default_values[field_wrapper.name]
303-
field_wrapper.set_default(field_default_value)
304-
unknown_names.remove(field_wrapper.name)
339+
if field_wrapper.is_subgroup:
340+
continue
341+
# Set default for regular field
342+
field_wrapper.set_default(remaining_fields[field_wrapper.name])
343+
remaining_fields.pop(field_wrapper.name)
344+
345+
# Handle nested dataclass fields
305346
for nested_dataclass_wrapper in self._children:
306-
if nested_dataclass_wrapper.name not in field_default_values:
347+
if nested_dataclass_wrapper.name not in remaining_fields:
307348
continue
308-
field_default_value = field_default_values[nested_dataclass_wrapper.name]
309-
nested_dataclass_wrapper.set_default(field_default_value)
310-
unknown_names.remove(nested_dataclass_wrapper.name)
311-
unknown_names.discard("_type_")
312-
if unknown_names:
349+
nested_dataclass_wrapper.set_default(remaining_fields[nested_dataclass_wrapper.name])
350+
remaining_fields.pop(nested_dataclass_wrapper.name)
351+
352+
# Check for any unhandled fields
353+
remaining_fields.pop("_type_", None) # Remove _type_ if present as it's handled separately
354+
if remaining_fields:
313355
raise RuntimeError(
314-
f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!"
356+
f"{sorted(remaining_fields.keys())} are not fields of {self.dataclass} at path {self.dest!r}!"
315357
)
316358

317359
@property

Diff for: simple_parsing/wrappers/field_wrapper.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -719,12 +719,13 @@ def default(self) -> Any:
719719
if it has a default value
720720
"""
721721

722-
if self._default is not None:
722+
if self.is_subgroup:
723+
# For subgroups, always use the subgroup_default to maintain consistency
724+
default = self.subgroup_default
725+
elif self._default is not None:
723726
# If a default value was set manually from the outside (e.g. from the DataclassWrapper)
724727
# then use that value.
725728
default = self._default
726-
elif self.is_subgroup:
727-
default = self.subgroup_default
728729
elif any(
729730
parent_default not in (None, argparse.SUPPRESS)
730731
for parent_default in self.parent.defaults

Diff for: test/test_subgroup_minimal.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
_type_: type_a
12
model_a_param: test

Diff for: test/test_subgroups.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,8 @@ class TrainConfig(TestSetup):
964964
# Create a config file
965965
config_path = Path(__file__).parent / "test_subgroup_minimal.yaml"
966966
config = {
967-
"model_a_param": "test" # This should work but currently fails
967+
"_type_": "type_a", # Specify we want to use ModelTypeA
968+
"model_a_param": "test" # Set the parameter
968969
}
969970
with config_path.open('w') as f:
970971
yaml.dump(config, f)
@@ -980,8 +981,8 @@ class TrainConfig(TestSetup):
980981
# This should work the same way as CLI args
981982
config_from_file = parse(
982983
TrainConfig,
983-
args=shlex.split(f"--config_path {config_path}"),
984-
add_config_path_arg=True,
984+
config_path=config_path,
985+
args=[], # Pass empty list to prevent pytest args from being parsed
985986
)
986987

987988
# These assertions should pass but currently fail because the config file parameters

0 commit comments

Comments
 (0)