Skip to content

Commit 93ff6d4

Browse files
xingyousongcopybara-github
authored andcommitted
1. Upgrade to Py 3.12
2. Fix SQL datastore's commit logic 3. Fix performance_test.py PiperOrigin-RevId: 721981221
1 parent 615bb2a commit 93ff6d4

File tree

6 files changed

+73
-62
lines changed

6 files changed

+73
-62
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
strategy:
1616
fail-fast: false
1717
matrix:
18-
python-version: ["3.11"] # 3.x disabled b/c of 3.12 test failures w/ GRPC.
18+
python-version: ["3.x"]
1919
suffix: ["core", "benchmarks", "algorithms", "clients", "pyglove", "raytune"]
2020
include:
2121
- suffix: "clients"

.github/workflows/pypi-publish-dev.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- name: Set up Python
1818
uses: actions/setup-python@v4
1919
with:
20-
python-version: '3.11'
20+
python-version: '3.12'
2121
- name: Install dependencies
2222
# NOTE: grpcio-tools needs to be periodically updated to support later Python versions.
2323
run: |

.github/workflows/pypi-publish.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Set up Python
1616
uses: actions/setup-python@v4
1717
with:
18-
python-version: '3.11'
18+
python-version: '3.12'
1919
- name: Install dependencies
2020
# NOTE: grpcio-tools needs to be periodically updated to support later Python versions.
2121
run: |

vizier/_src/service/performance_test.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import multiprocessing.pool
2020
import time
21-
from absl import logging
2221

22+
from absl import logging
2323
from vizier._src.service import constants
2424
from vizier._src.service import vizier_client
2525
from vizier._src.service import vizier_server
@@ -41,23 +41,15 @@ def setUpClass(cls):
4141
)
4242
vizier_client.environment_variables.server_endpoint = cls.server.endpoint
4343

44-
@parameterized.parameters(
45-
(1, 10, 2),
46-
(2, 10, 2),
47-
(10, 10, 2),
48-
(50, 5, 2),
49-
(100, 5, 2),
50-
)
44+
@parameterized.parameters((1, 10), (2, 10), (10, 10), (50, 5), (100, 5))
5145
def test_multiple_clients_basic(
52-
self, num_simultaneous_clients, num_trials_per_client, dimension
46+
self, num_simultaneous_clients, num_trials_per_client
5347
):
5448
def fn(client_id: int):
55-
experimenter = experimenters.BBOBExperimenterFactory(
56-
'Sphere', dimension
57-
)()
49+
experimenter = experimenters.BBOBExperimenterFactory('Sphere', 2)()
5850
problem_statement = experimenter.problem_statement()
5951
study_config = pyvizier.StudyConfig.from_problem(problem_statement)
60-
study_config.algorithm = pyvizier.Algorithm.NSGA2
52+
study_config.algorithm = pyvizier.Algorithm.RANDOM_SEARCH
6153

6254
client = vizier_client.create_or_load_study(
6355
owner_id='my_username',

vizier/_src/service/sql_datastore.py

+61-45
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,19 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource:
103103

104104
with self._lock:
105105
try:
106-
self._connection.execute(owner_query)
107-
self._connection.commit()
106+
self._write_or_rollback(owner_query)
108107
except sqla.exc.IntegrityError:
109108
logging.info('Owner with name %s currently exists.', owner_name)
110-
self._connection.rollback()
109+
111110
try:
112-
self._connection.execute(study_query)
113-
self._connection.commit()
114-
return study_resource
111+
self._write_or_rollback(study_query)
115112
except sqla.exc.IntegrityError as e:
116-
self._connection.rollback()
117113
raise AlreadyExistsError(
118114
'Study with name %s already exists.' % study.name
119115
) from e
116+
self._connection.commit()
117+
118+
return study_resource
120119

121120
def load_study(self, study_name: str) -> study_pb2.Study:
122121
query = sqla.select(self._studies_table)
@@ -151,8 +150,9 @@ def update_study(self, study: study_pb2.Study) -> resources.StudyResource:
151150
with self._lock:
152151
if not self._connection.execute(eq).fetchone()[0]:
153152
raise NotFoundError('Study %s does not exist.' % study.name)
154-
self._connection.execute(uq)
153+
self._write_or_rollback(uq)
155154
self._connection.commit()
155+
156156
return study_resource
157157

158158
def delete_study(self, study_name: str) -> None:
@@ -175,8 +175,8 @@ def delete_study(self, study_name: str) -> None:
175175
with self._lock:
176176
if not self._connection.execute(eq).fetchone()[0]:
177177
raise NotFoundError('Study %s does not exist.' % study_name)
178-
self._connection.execute(dsq)
179-
self._connection.execute(dtq)
178+
self._write_or_rollback(dsq)
179+
self._write_or_rollback(dtq)
180180
self._connection.commit()
181181

182182
def list_studies(self, owner_name: str) -> List[study_pb2.Study]:
@@ -210,14 +210,14 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
210210

211211
with self._lock:
212212
try:
213-
self._connection.execute(query)
214-
self._connection.commit()
215-
return trial_resource
213+
self._write_or_rollback(query)
216214
except sqla.exc.IntegrityError as e:
217-
self._connection.rollback()
218215
raise AlreadyExistsError(
219216
'Trial with name %s already exists.' % trial.name
220217
) from e
218+
self._connection.commit()
219+
220+
return trial_resource
221221

222222
def get_trial(self, trial_name: str) -> study_pb2.Trial:
223223
query = sqla.select(self._trials_table)
@@ -253,7 +253,7 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
253253
with self._lock:
254254
if not self._connection.execute(eq).fetchone()[0]:
255255
raise NotFoundError('Trial %s does not exist.' % trial.name)
256-
self._connection.execute(uq)
256+
self._write_or_rollback(uq)
257257
self._connection.commit()
258258

259259
return trial_resource
@@ -291,7 +291,7 @@ def delete_trial(self, trial_name: str) -> None:
291291
with self._lock:
292292
if not self._connection.execute(eq).fetchone()[0]:
293293
raise NotFoundError('Trial %s does not exist.' % trial_name)
294-
self._connection.execute(dq)
294+
self._write_or_rollback(dq)
295295
self._connection.commit()
296296

297297
def max_trial_id(self, study_name: str) -> int:
@@ -330,16 +330,16 @@ def create_suggestion_operation(
330330
serialized_op=operation.SerializeToString(),
331331
)
332332

333-
try:
334-
with self._lock:
335-
self._connection.execute(query)
336-
self._connection.commit()
337-
return resource
338-
except sqla.exc.IntegrityError as e:
339-
self._connection.rollback()
340-
raise AlreadyExistsError(
341-
'Suggest Op with name %s already exists.' % operation.name
342-
) from e
333+
with self._lock:
334+
try:
335+
self._write_or_rollback(query)
336+
except sqla.exc.IntegrityError as e:
337+
raise AlreadyExistsError(
338+
'Suggest Op with name %s already exists.' % operation.name
339+
) from e
340+
self._connection.commit()
341+
342+
return resource
343343

344344
def get_suggestion_operation(
345345
self, operation_name: str
@@ -386,8 +386,9 @@ def update_suggestion_operation(
386386
with self._lock:
387387
if not self._connection.execute(eq).fetchone()[0]:
388388
raise NotFoundError('Suggest op %s does not exist.' % operation.name)
389-
self._connection.execute(uq)
389+
self._write_or_rollback(uq)
390390
self._connection.commit()
391+
391392
return resource
392393

393394
def list_suggestion_operations(
@@ -474,16 +475,16 @@ def create_early_stopping_operation(
474475
serialized_op=operation.SerializeToString(),
475476
)
476477

477-
try:
478-
with self._lock:
479-
self._connection.execute(query)
480-
self._connection.commit()
481-
return resource
482-
except sqla.exc.IntegrityError as e:
483-
self._connection.rollback()
484-
raise AlreadyExistsError(
485-
'Early stopping op with name %s already exists.' % operation.name
486-
) from e
478+
with self._lock:
479+
try:
480+
self._write_or_rollback(query)
481+
except sqla.exc.IntegrityError as e:
482+
raise AlreadyExistsError(
483+
'Early stopping op with name %s already exists.' % operation.name
484+
) from e
485+
self._connection.commit()
486+
487+
return resource
487488

488489
def get_early_stopping_operation(
489490
self, operation_name: str
@@ -553,8 +554,7 @@ def update_metadata(
553554
sq = sq.where(self._studies_table.c.study_name == study_name)
554555

555556
with self._lock:
556-
study_result = self._connection.execute(sq)
557-
row = study_result.fetchone()
557+
row = self._connection.execute(sq).fetchone()
558558
if not row:
559559
raise NotFoundError('No such study:', s_resource.name)
560560
original_study = study_pb2.Study.FromString(row.serialized_study)
@@ -567,8 +567,7 @@ def update_metadata(
567567
usq = sqla.update(self._studies_table)
568568
usq = usq.where(self._studies_table.c.study_name == study_name)
569569
usq = usq.values(serialized_study=original_study.SerializeToString())
570-
self._connection.execute(usq)
571-
self._connection.commit()
570+
self._write_or_rollback(usq)
572571

573572
# Split the trial-related metadata by Trial.
574573
split_metadata = collections.defaultdict(list)
@@ -583,9 +582,9 @@ def update_metadata(
583582
# Obtain original trial.
584583
otq = sqla.select(self._trials_table)
585584
otq = otq.where(self._trials_table.c.trial_name == trial_name)
586-
trial_result = self._connection.execute(otq)
587-
row = trial_result.fetchone()
585+
row = self._connection.execute(otq).fetchone()
588586
if not row:
587+
self._connection.rollback()
589588
raise NotFoundError('No such trial:', trial_name)
590589
original_trial = study_pb2.Trial.FromString(row.serialized_trial)
591590

@@ -594,5 +593,22 @@ def update_metadata(
594593
utq = sqla.update(self._trials_table)
595594
utq = utq.where(self._trials_table.c.trial_name == trial_name)
596595
utq = utq.values(serialized_trial=original_trial.SerializeToString())
597-
self._connection.execute(utq)
598-
self._connection.commit()
596+
self._write_or_rollback(utq)
597+
598+
# Commit ALL changes if everything went well.
599+
self._connection.commit()
600+
601+
def _write_or_rollback(self, write_query: sqla.sql.Executable) -> None:
602+
"""Wraps connection.execute() to roll back on write query failure.
603+
604+
Args:
605+
write_query: The write query to execute.
606+
607+
Raises:
608+
sqla.exc.DatabaseError: Generic database error.
609+
"""
610+
try:
611+
self._connection.execute(write_query)
612+
except sqla.exc.DatabaseError as e:
613+
self._connection.rollback()
614+
raise e

vizier/_src/service/vizier_server.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from concurrent import futures
2323
import datetime
2424
import time
25+
from typing import Optional
2526

2627
import attr
2728
import grpc
@@ -48,7 +49,9 @@ class DefaultVizierServer:
4849
"""
4950

5051
_host: str = attr.field(default='localhost')
51-
_database_url: str = attr.field(default=constants.SQL_LOCAL_URL, kw_only=True)
52+
_database_url: Optional[str] = attr.field(
53+
default=constants.SQL_LOCAL_URL, kw_only=True
54+
)
5255
_policy_factory: pythia.PolicyFactory = attr.field(
5356
factory=service_policy_factory_lib.DefaultPolicyFactory, kw_only=True
5457
)

0 commit comments

Comments
 (0)