Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions hydra/_internal/core_plugins/basic_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence

from omegaconf import DictConfig, OmegaConf
from omegaconf._utils import is_structured_config

from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.override_parser.types import Override
from hydra.core.override_parser.types import Override, QuotedString
from hydra.core.utils import JobReturn
from hydra.errors import HydraException
from hydra.plugins.launcher import Launcher
Expand Down Expand Up @@ -93,6 +94,78 @@ def setup(
config=config,
)

@staticmethod
def simplify_overrides(
overrides: List[Override],
) -> List[Override]:
# this would simplify the overrides by removing those that are overridden later
# in the list.
# e.g. a=1 and later a=10 would remove the first override.
lists = []
# NOTE: key -> index of last override with no dict value. (e.g. a=1,2,3)
# any override for key before this would be skipped.
last_primitive = {}
last_dict = {}
last_defaults: Dict[str, int] = {}

is_defaults: Dict[int, bool] = {}
is_primitive: Dict[int, bool] = {}
has_dict: Dict[int, bool] = {}

# check value should override earlier ones
# TODO: handle extend_list
def check_write_override(x: Any):
return (
isinstance(x, (str, int, float, bool, list, QuotedString)) or x is None
)

def check_has_dict(x: Any):
return isinstance(x, dict) or is_structured_config(x)

for i, override in enumerate(overrides):
if override.config_loader is None:
continue
is_group = (
len(override.config_loader.get_group_options(override.key_or_group)) > 0
)

key = override.get_key_element()
_write = False
_has_dict = False
if override.is_sweep_override():
if override.is_discrete_sweep():
_write = all(override.sweep_iterator(check_write_override))
_has_dict = any(override.sweep_iterator(check_has_dict))
else:
_write = check_write_override(override.value())
_has_dict = check_has_dict(override.value())

if _write:
if is_group:
is_defaults[i] = True
if override.is_change():
last_defaults[key] = i
else:
is_primitive[i] = True
last_primitive[key] = i
if _has_dict:
has_dict[i] = True
last_dict[key] = i

for i, override in enumerate(overrides):
key = override.get_key_element()
if is_primitive.get(i, False) and (
last_primitive.get(key, -1) != i or last_dict.get(key, -1) > i
):
continue
if has_dict.get(i, False) and last_primitive.get(key, -1) > i:
continue
if is_defaults.get(i, False) and last_defaults.get(key, -1) > i:
continue
lists.append(override)

return lists

@staticmethod
def split_overrides_to_chunks(
lst: List[List[str]], n: Optional[int]
Expand All @@ -108,13 +181,13 @@ def split_arguments(
overrides: List[Override], max_batch_size: Optional[int]
) -> List[List[List[str]]]:
lists = []
final_overrides = OrderedDict()
overrides = BasicSweeper.simplify_overrides(overrides)
for override in overrides:
if override.is_sweep_override():
if override.is_discrete_sweep():
key = override.get_key_element()
sweep = [f"{key}={val}" for val in override.sweep_string_iterator()]
final_overrides[key] = sweep
lists.append(sweep)
else:
assert override.value_type is not None
raise HydraException(
Expand All @@ -123,10 +196,7 @@ def split_arguments(
else:
key = override.get_key_element()
value = override.get_value_element_as_str()
final_overrides[key] = [f"{key}={value}"]

for _, v in final_overrides.items():
lists.append(v)
lists.append([f"{key}={value}"])

all_batches = [list(x) for x in itertools.product(*lists)]
assert max_batch_size is None or max_batch_size > 0
Expand Down
6 changes: 6 additions & 0 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ class Override:
# Configs repo
config_loader: Optional[ConfigLoader] = None

def is_change(self) -> bool:
"""
:return: True if this override represents a change of a config value or config group option
"""
return self.type == OverrideType.CHANGE

def is_delete(self) -> bool:
"""
:return: True if this override represents a deletion of a config value or config group option
Expand Down
91 changes: 90 additions & 1 deletion tests/test_basic_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from pytest import mark, param

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra._internal.core_plugins.basic_sweeper import BasicSweeper
from hydra._internal.utils import create_config_search_path
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.test_utils.test_utils import assert_multiline_regex_search, run_process

Expand Down Expand Up @@ -48,19 +50,106 @@
),
param(["a=range(0,3)"], None, [[["a=0"], ["a=1"], ["a=2"]]], id="range"),
param(["a=range(3)"], None, [[["a=0"], ["a=1"], ["a=2"]]], id="range_no_start"),
param(["a=1,2,3", "a=20"], None, [[["a=20"]]], id="override_same_key1"),
param(
["a=2", "a=10,20"], None, [[["a=10"], ["a=20"]]], id="override_same_key2"
),
param(
["a=1,2,3", "a=10,20"],
None,
[[["a=10"], ["a=20"]]],
id="override_same_key3",
),
param(["a={x:1},{x:2}"], None, [[["a={x:1}"], ["a={x:2}"]]], id="dicts"),
param(
["a={x:1},{x:2}", "+a={y:10},{y:20}"],
None,
[
[
["a={x:1}", "+a={y:10}"],
["a={x:1}", "+a={y:20}"],
["a={x:2}", "+a={y:10}"],
["a={x:2}", "+a={y:20}"],
]
],
id="dicts_multiple_with_plus",
),
param(
["a={x:1},{x:2}", "a={y:10},{y:20}"],
None,
[
[
["a={x:1}", "a={y:10}"],
["a={x:1}", "a={y:20}"],
["a={x:2}", "a={y:10}"],
["a={x:2}", "a={y:20}"],
]
],
id="dicts_multiple",
),
param(["a=1,2,3", "a={x:1}"], None, [[["a={x:1}"]]], id="override_with_dict1"),
param(
["a=1,2,3", "a={x:1},{x:2}"],
None,
[[["a={x:1}"], ["a={x:2}"]]],
id="override_with_dict2",
),
param(
["a={x:1}", "a=1,2,3"],
None,
[[["a=1"], ["a=2"], ["a=3"]]],
id="override_with_dict3",
),
param(["a={x:1},{x:2}", "a=1"], None, [[["a=1"]]], id="override_with_dict4"),
],
)
def test_split(
args: List[str], max_batch_size: Optional[int], expected: List[List[List[str]]]
) -> None:
parser = OverridesParser.create()

config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
parser = OverridesParser.create(config_loader)
ret = BasicSweeper.split_arguments(
parser.parse_overrides(args), max_batch_size=max_batch_size
)
lret = [list(x) for x in ret]
assert lret == expected


@mark.parametrize(
"args,expected",
[
param(["a=1", "b=2", "a=3"], ["b=2", "a=3"], id="simple_override"),
param(["a=1,2", "a=3,4"], ["a=3,4"], id="override_split"),
param(
["a=1", "b=2", "+a={x:10}", "+a={y:20}"],
["a=1", "b=2", "+a={x:10}", "+a={y:20}"],
id="override_plus",
),
param(
["a=1", "b=2", "a={x:10}", "a={y:20}"],
["b=2", "a={x:10}", "a={y:20}"],
id="override_plus",
),
param(["a={x:1}", "a={y:2}"], ["a={x:1}", "a={y:2}"], id="override_dict"),
param(
["a=1,2", "+a={x:10},{y:20}", "a=3,4"],
["+a={x:10},{y:20}", "a=3,4"],
id="override_mixed",
),
param(["a=1,2", "a={x:10},{y:20}", "a=3,4"], ["a=3,4"], id="override_mixed"),
param(["+a=xx,yy", "+a=[zz]"], ["+a=[zz]"], id="override_plus_list"),
],
)
def test_simplify(args: List[str], expected: List[str]) -> None:
config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
parser = OverridesParser.create(config_loader)
overrides = parser.parse_overrides(args)
simplified = BasicSweeper.simplify_overrides(overrides)
expected_overrides = parser.parse_overrides(expected)
assert simplified == expected_overrides


def test_partial_failure(
tmpdir: Any,
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_examples/test_basic_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
dedent(
"""\
[HYDRA] Launching 2 jobs locally
[HYDRA] \t#0 : db=mysql db.timeout=5
[HYDRA] \t#0 : db.timeout=5 db=mysql
driver=mysql, timeout=5
[HYDRA] \t#1 : db=mysql db.timeout=10
[HYDRA] \t#1 : db.timeout=10 db=mysql
driver=mysql, timeout=10"""
),
),
Expand All @@ -48,13 +48,13 @@
dedent(
"""\
[HYDRA] Launching 4 jobs locally
[HYDRA] \t#0 : db=mysql db.timeout=5 db.user=one
[HYDRA] \t#0 : db.timeout=5 db=mysql db.user=one
driver=mysql, timeout=5
[HYDRA] \t#1 : db=mysql db.timeout=5 db.user=two
[HYDRA] \t#1 : db.timeout=5 db=mysql db.user=two
driver=mysql, timeout=5
[HYDRA] \t#2 : db=mysql db.timeout=10 db.user=one
[HYDRA] \t#2 : db.timeout=10 db=mysql db.user=one
driver=mysql, timeout=10
[HYDRA] \t#3 : db=mysql db.timeout=10 db.user=two
[HYDRA] \t#3 : db.timeout=10 db=mysql db.user=two
driver=mysql, timeout=10"""
),
),
Expand Down