Skip to content

Commit 1bf02d0

Browse files
Jiawei Yangmeta-codesync[bot]
authored andcommitted
Support SQLAlchemy 2.0 in FB storage extensions (#5203) (#5203)
Summary: The Meta-internal Ax storage extensions in ax/fb/storage/ have two SA 2.0 incompatibilities not present in the OSS surface: a raw SQL string passed to session.execute in fb sqa_store db.py (SA 2.0 requires text() wrapping), and external_store.py uses Connection.execute() for writes without explicit transaction (SA 2.0 removed implicit autocommit, so writes were silently rolling back), uses string-keyed Row indexing (SA 2.0 requires row._mapping[key]), and consumes a Result generator outside the connection context (SA 2.0 closes the Result on connection close). This diff wraps SHOW DATABASES with text(), switches _write to engine.begin() for transactional commit, migrates _decode_row to row._mapping access, and materializes the read_raw_data result list inside the with conn block. Adds tests_sa2 dual-version Buck targets for fb sqa_store, fb external_store, and fb prod_tests, plus a SQLAlchemy2CompatTest smoke test that exercises the libfb.py.db_locator -> creator -> engine -> session -> SELECT 1 path and asserts EXPECTED_SA_MAJOR. Pull Request resolved: #5203 Reviewed By: mgarrard, yangjoanna Differential Revision: D104875016 Pulled By: LeoMoonStar fbshipit-source-id: da1d06e90166ede0a04c64a01a86c7d49c84bc8a
1 parent 657841e commit 1bf02d0

9 files changed

Lines changed: 35 additions & 0 deletions

File tree

ax/storage/sqa_store/decoder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
# pyre-ignore-all-errors[6, 8, 9]
89

910
import re
1011
import warnings
@@ -281,13 +282,16 @@ def _init_experiment_from_sqa(
281282
)
282283

283284
return Experiment(
285+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
284286
name=experiment_sqa.name,
287+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
285288
description=experiment_sqa.description,
286289
search_space=search_space,
287290
optimization_config=opt_config,
288291
tracking_metrics=all_metrics,
289292
runner=runner,
290293
status_quo=status_quo,
294+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
291295
is_test=experiment_sqa.is_test,
292296
properties=properties,
293297
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
@@ -333,11 +337,13 @@ def _init_mt_experiment_from_sqa(
333337
)
334338

335339
default_trial_type = none_throws(experiment_sqa.default_trial_type)
340+
# pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str.
336341
trial_type_to_runner: dict[str, Runner | None] = {
337342
none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner)
338343
for sqa_runner in experiment_sqa.runners
339344
}
340345
if len(trial_type_to_runner) == 0:
346+
# pyre-ignore[9]: SA 2.0 Column[Optional[str]] keys; runtime str.
341347
trial_type_to_runner = {default_trial_type: None}
342348
trial_types_with_metrics = {
343349
metric.trial_type
@@ -347,13 +353,18 @@ def _init_mt_experiment_from_sqa(
347353
# trial_type_to_runner is instantiated to map all trial types to None,
348354
# so the trial types are associated with the experiment. This is
349355
# important for adding metrics.
356+
# pyre-ignore[6]: SA 2.0 Column[T] keys vs str keys.
350357
trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics))
351358

352359
experiment = MultiTypeExperiment(
360+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
353361
name=experiment_sqa.name,
362+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
354363
description=experiment_sqa.description,
355364
search_space=search_space,
365+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
356366
default_trial_type=default_trial_type,
367+
# pyre-ignore[6]: SA 2.0 Column[T] vs plain T param.
357368
default_runner=trial_type_to_runner.get(default_trial_type),
358369
optimization_config=opt_config,
359370
status_quo=status_quo,

ax/storage/sqa_store/delete.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-strict
7+
# pyre-ignore-all-errors[6]
78

89
from logging import Logger
910
from typing import cast

ax/storage/sqa_store/json.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class JSONEncodedLongText(JSONEncodedObject):
9595

9696
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
9797
JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject)
98+
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
9899
JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject)
100+
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
99101
JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText)
102+
# pyre-ignore[9]: SA 2.0 typed as_mutable returns TypeEngine; runtime TypeDecorator.
100103
JSONEncodedLongTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedLongText)

ax/storage/sqa_store/load.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
# pyre-ignore-all-errors[6, 8]
89

910
import logging
1011
from collections.abc import Mapping

ax/storage/sqa_store/reduced_state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
# pyre-ignore-all-errors[24]
89

910

1011
from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun, SQATrial

ax/storage/sqa_store/save.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
# pyre-ignore-all-errors[6, 8, 9]
89

910
import os
1011
from collections.abc import Callable, Generator, Sequence

ax/storage/sqa_store/sqa_classes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
# pyre-ignore-all-errors[8]
89

910
from __future__ import annotations
1011

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,10 +1841,12 @@ def test_parameter_validation(self) -> None:
18411841
with session_scope() as session:
18421842
session.add(sqa_parameter)
18431843

1844+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
18441845
sqa_parameter.experiment_id = 0
18451846
with session_scope() as session:
18461847
session.add(sqa_parameter)
18471848
with self.assertRaises(ValueError):
1849+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
18481850
sqa_parameter.generator_run_id = 0
18491851
with session_scope() as session:
18501852
session.add(sqa_parameter)
@@ -1858,6 +1860,7 @@ def test_parameter_validation(self) -> None:
18581860
with session_scope() as session:
18591861
session.add(sqa_parameter)
18601862
with self.assertRaises(ValueError):
1863+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
18611864
sqa_parameter.experiment_id = 0
18621865
with session_scope() as session:
18631866
session.add(sqa_parameter)
@@ -1907,10 +1910,12 @@ def test_parameter_constraint_validation(self) -> None:
19071910
with session_scope() as session:
19081911
session.add(sqa_parameter_constraint)
19091912

1913+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
19101914
sqa_parameter_constraint.experiment_id = 0
19111915
with session_scope() as session:
19121916
session.add(sqa_parameter_constraint)
19131917
with self.assertRaises(ValueError):
1918+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
19141919
sqa_parameter_constraint.generator_run_id = 0
19151920
with session_scope() as session:
19161921
session.add(sqa_parameter_constraint)
@@ -1924,6 +1929,7 @@ def test_parameter_constraint_validation(self) -> None:
19241929
with session_scope() as session:
19251930
session.add(sqa_parameter_constraint)
19261931
with self.assertRaises(ValueError):
1932+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
19271933
sqa_parameter_constraint.experiment_id = 0
19281934
with session_scope() as session:
19291935
session.add(sqa_parameter_constraint)
@@ -1961,10 +1967,12 @@ def test_metric_validation(self) -> None:
19611967
with session_scope() as session:
19621968
session.add(sqa_metric)
19631969

1970+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
19641971
sqa_metric.experiment_id = 0
19651972
with session_scope() as session:
19661973
session.add(sqa_metric)
19671974
with self.assertRaises(ValueError):
1975+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
19681976
sqa_metric.generator_run_id = 0
19691977
with session_scope() as session:
19701978
session.add(sqa_metric)
@@ -1979,6 +1987,7 @@ def test_metric_validation(self) -> None:
19791987
with session_scope() as session:
19801988
session.add(sqa_metric)
19811989
with self.assertRaises(ValueError):
1990+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
19821991
sqa_metric.experiment_id = 0
19831992
with session_scope() as session:
19841993
session.add(sqa_metric)
@@ -2025,13 +2034,16 @@ def test_metric_decode_failure(self) -> None:
20252034
with self.assertRaises(SQADecodeError):
20262035
self.decoder.metric_from_sqa(sqa_metric)
20272036

2037+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
20282038
sqa_metric.metric_type = CORE_METRIC_REGISTRY[BraninMetric]
20292039
# pyre-fixme[8]: Attribute has type `MetricIntent`; used as `str`.
20302040
sqa_metric.intent = "foobar"
20312041
with self.assertRaises(SQADecodeError):
20322042
self.decoder.metric_from_sqa(sqa_metric)
20332043

2044+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
20342045
sqa_metric.intent = MetricIntent.TRACKING
2046+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
20352047
sqa_metric.properties = {}
20362048
with self.assertRaises(ValueError):
20372049
self.decoder.metric_from_sqa(sqa_metric)
@@ -2080,10 +2092,12 @@ def test_runner_validation(self) -> None:
20802092
with session_scope() as session:
20812093
session.add(sqa_runner)
20822094

2095+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
20832096
sqa_runner.experiment_id = 0
20842097
with session_scope() as session:
20852098
session.add(sqa_runner)
20862099
with self.assertRaises(ValueError):
2100+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
20872101
sqa_runner.trial_id = 0
20882102
with session_scope() as session:
20892103
session.add(sqa_runner)
@@ -2094,6 +2108,7 @@ def test_runner_validation(self) -> None:
20942108
with session_scope() as session:
20952109
session.add(sqa_runner)
20962110
with self.assertRaises(ValueError):
2111+
# pyre-ignore[8]: SA 2.0 Column[T] attr; runtime assign is fine.
20972112
sqa_runner.experiment_id = 0
20982113
with session_scope() as session:
20992114
session.add(sqa_runner)

ax/storage/sqa_store/validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535

3636
def listens_for_multiple(
37+
# pyre-ignore[24]: SA 2.0 requires a type param on InstrumentedAttribute.
3738
targets: list[InstrumentedAttribute],
3839
identifier: str,
3940
*args: Any,

0 commit comments

Comments
 (0)