Skip to content

Commit 60f5ca9

Browse files
mgarrardmeta-codesync[bot]
authored andcommitted
Update reset_gs code to orphan old gs if not indicated for deletion (facebook#4735)
Summary: Pull Request resolved: facebook#4735 V2: After review we decided to allow for multiple GS to be associated with an experiment in the db and only load the most recent one. V1: I was looking at this method and realized that although it will work within a single instance of client, upon saving/re-loading the experiment load will fail because there will be two instances of gs associated with an experiment and this will cause a merge error. This implements: 1. an orphan gs method -- allows the legacy gs to still live in the db and just sets expeirment column to none, essentially orphaning it, but we coud go look it up in the future if we wanted 2. cleans up some pyres taht weren't doing anything i saw along the way 3. extends tests to account for this behavior Reviewed By: mpolson64 Differential Revision: D90064851 fbshipit-source-id: a5083dbadb4d51e26b546ec1ec690f1ea9daa965
1 parent a42b4ea commit 60f5ca9

4 files changed

Lines changed: 47 additions & 10 deletions

File tree

ax/storage/sqa_store/load.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
import logging
10+
911
from math import ceil
1012
from typing import Any, cast, Mapping
1113

@@ -53,6 +55,8 @@
5355
from sqlalchemy.orm import defaultload, joinedload, lazyload, noload
5456
from sqlalchemy.orm.exc import DetachedInstanceError
5557

58+
logger: logging.Logger = logging.getLogger(__name__)
59+
5660

5761
# ---------------------------- Loading `Experiment`. ---------------------------
5862

@@ -537,21 +541,32 @@ def _load_generation_strategy_by_id(
537541
def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> int | None:
538542
"""Get DB ID of the generation strategy, associated with the experiment
539543
with the given name if its in DB, return None otherwise.
544+
545+
If multiple generation strategies are associated with the experiment,
546+
returns the latest one (highest DB ID).
540547
"""
541548
exp_sqa_class = decoder.config.class_to_sqa_class[Experiment]
542549
gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy]
543550
with session_scope() as session:
544-
sqa_gs_id = (
551+
sqa_gs_ids = (
545552
session.query(gs_sqa_class.id) # pyre-ignore[16]
546553
.join(exp_sqa_class.generation_strategy) # pyre-ignore[16]
547554
# pyre-fixme[16]: `SQABase` has no attribute `name`.
548555
.filter(exp_sqa_class.name == experiment_name)
549-
.one_or_none()
556+
.order_by(gs_sqa_class.id.desc())
557+
.all()
550558
)
551559

552-
if sqa_gs_id is None:
560+
if not sqa_gs_ids:
553561
return None
554-
return sqa_gs_id[0]
562+
563+
if len(sqa_gs_ids) > 1:
564+
logger.warning(
565+
f"Found {len(sqa_gs_ids)} generation strategies for experiment "
566+
f"{experiment_name}. Loading the latest one (id={sqa_gs_ids[0][0]})."
567+
)
568+
569+
return sqa_gs_ids[0][0]
555570

556571

557572
def get_generation_strategy_sqa(

ax/storage/sqa_store/sqa_enum.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,13 @@
1616
class BaseNullableEnum(types.TypeDecorator):
1717
cache_ok = True
1818

19-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
2019
def __init__(self, enum: Any, *arg: list[Any], **kw: dict[Any, Any]) -> None:
2120
types.TypeDecorator.__init__(self, *arg, **kw)
2221
# pyre-fixme[4]: Attribute must be annotated.
2322
self._member_map = enum._member_map_
2423
# pyre-fixme[4]: Attribute must be annotated.
2524
self._value2member_map = enum._value2member_map_
2625

27-
# pyre-fixme[3]: Return annotation cannot be `Any`.
28-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
2926
def process_bind_param(self, value: Any, dialect: Any) -> Any:
3027
if value is None:
3128
return value
@@ -40,8 +37,6 @@ def process_bind_param(self, value: Any, dialect: Any) -> Any:
4037
)
4138
return val._value_
4239

43-
# pyre-fixme[3]: Return annotation cannot be `Any`.
44-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
4540
def process_result_value(self, value: Any, dialect: Any) -> Any:
4641
if value is None:
4742
return value

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,6 +2910,34 @@ def test_delete_generation_strategy_max_gs_to_delete(self) -> None:
29102910
# Full GS fails the equality check
29112911
self.assertEqual(str(generation_strategy), str(loaded_generation_strategy))
29122912

2913+
def test_load_latest_generation_strategy_when_multiple_exist(self) -> None:
2914+
experiment = get_branin_experiment()
2915+
gs1 = choose_generation_strategy_legacy(experiment.search_space)
2916+
gs1.experiment = experiment
2917+
save_experiment(experiment)
2918+
save_generation_strategy(generation_strategy=gs1)
2919+
self.assertEqual(
2920+
gs1.db_id,
2921+
load_generation_strategy_by_experiment_name(experiment.name).db_id,
2922+
)
2923+
2924+
# create a second generation strategy for the experiment
2925+
gs2 = choose_generation_strategy_legacy(experiment.search_space)
2926+
gs2._name = "second_gs"
2927+
gs2.experiment = experiment
2928+
save_generation_strategy(generation_strategy=gs2)
2929+
2930+
# check that the latest generation stragey is loaded
2931+
with self.assertLogs(
2932+
"ax.storage.sqa_store.load", level=logging.WARNING
2933+
) as logs:
2934+
loaded_gs = load_generation_strategy_by_experiment_name(experiment.name)
2935+
self.assertEqual(loaded_gs.db_id, gs2.db_id)
2936+
self.assertEqual(loaded_gs.name, gs2.name)
2937+
self.assertTrue(
2938+
any("Found 2 generation strategies" in log for log in logs.output)
2939+
)
2940+
29132941
def test_query_historical_experiments_given_parameters(self) -> None:
29142942
# This test validates the query behavior for historical experiments.
29152943
config = SQAConfig(experiment_type_enum=TestExperimentTypeEnum)

ax/storage/sqa_store/validation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def wrapper(fn: Callable) -> Callable:
5151
return wrapper
5252

5353

54-
# pyre-fixme[3]: Return annotation cannot be `Any`.
5554
def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any:
5655
"""Ensure that exactly one of `exactly_one_fields` has a value set."""
5756
values = [getattr(instance, field) is not None for field in exactly_one_fields]

0 commit comments

Comments
 (0)