@@ -103,20 +103,19 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource:
103
103
104
104
with self ._lock :
105
105
try :
106
- self ._connection .execute (owner_query )
107
- self ._connection .commit ()
106
+ self ._write_or_rollback (owner_query )
108
107
except sqla .exc .IntegrityError :
109
108
logging .info ('Owner with name %s currently exists.' , owner_name )
110
- self . _connection . rollback ()
109
+
111
110
try :
112
- self ._connection .execute (study_query )
113
- self ._connection .commit ()
114
- return study_resource
111
+ self ._write_or_rollback (study_query )
115
112
except sqla .exc .IntegrityError as e :
116
- self ._connection .rollback ()
117
113
raise AlreadyExistsError (
118
114
'Study with name %s already exists.' % study .name
119
115
) from e
116
+ self ._connection .commit ()
117
+
118
+ return study_resource
120
119
121
120
def load_study (self , study_name : str ) -> study_pb2 .Study :
122
121
query = sqla .select (self ._studies_table )
@@ -151,8 +150,9 @@ def update_study(self, study: study_pb2.Study) -> resources.StudyResource:
151
150
with self ._lock :
152
151
if not self ._connection .execute (eq ).fetchone ()[0 ]:
153
152
raise NotFoundError ('Study %s does not exist.' % study .name )
154
- self ._connection . execute (uq )
153
+ self ._write_or_rollback (uq )
155
154
self ._connection .commit ()
155
+
156
156
return study_resource
157
157
158
158
def delete_study (self , study_name : str ) -> None :
@@ -175,8 +175,8 @@ def delete_study(self, study_name: str) -> None:
175
175
with self ._lock :
176
176
if not self ._connection .execute (eq ).fetchone ()[0 ]:
177
177
raise NotFoundError ('Study %s does not exist.' % study_name )
178
- self ._connection . execute (dsq )
179
- self ._connection . execute (dtq )
178
+ self ._write_or_rollback (dsq )
179
+ self ._write_or_rollback (dtq )
180
180
self ._connection .commit ()
181
181
182
182
def list_studies (self , owner_name : str ) -> List [study_pb2 .Study ]:
@@ -210,14 +210,14 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
210
210
211
211
with self ._lock :
212
212
try :
213
- self ._connection .execute (query )
214
- self ._connection .commit ()
215
- return trial_resource
213
+ self ._write_or_rollback (query )
216
214
except sqla .exc .IntegrityError as e :
217
- self ._connection .rollback ()
218
215
raise AlreadyExistsError (
219
216
'Trial with name %s already exists.' % trial .name
220
217
) from e
218
+ self ._connection .commit ()
219
+
220
+ return trial_resource
221
221
222
222
def get_trial (self , trial_name : str ) -> study_pb2 .Trial :
223
223
query = sqla .select (self ._trials_table )
@@ -253,7 +253,7 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
253
253
with self ._lock :
254
254
if not self ._connection .execute (eq ).fetchone ()[0 ]:
255
255
raise NotFoundError ('Trial %s does not exist.' % trial .name )
256
- self ._connection . execute (uq )
256
+ self ._write_or_rollback (uq )
257
257
self ._connection .commit ()
258
258
259
259
return trial_resource
@@ -291,7 +291,7 @@ def delete_trial(self, trial_name: str) -> None:
291
291
with self ._lock :
292
292
if not self ._connection .execute (eq ).fetchone ()[0 ]:
293
293
raise NotFoundError ('Trial %s does not exist.' % trial_name )
294
- self ._connection . execute (dq )
294
+ self ._write_or_rollback (dq )
295
295
self ._connection .commit ()
296
296
297
297
def max_trial_id (self , study_name : str ) -> int :
@@ -330,16 +330,16 @@ def create_suggestion_operation(
330
330
serialized_op = operation .SerializeToString (),
331
331
)
332
332
333
- try :
334
- with self . _lock :
335
- self ._connection . execute (query )
336
- self . _connection . commit ()
337
- return resource
338
- except sqla . exc . IntegrityError as e :
339
- self . _connection . rollback ()
340
- raise AlreadyExistsError (
341
- 'Suggest Op with name %s already exists.' % operation . name
342
- ) from e
333
+ with self . _lock :
334
+ try :
335
+ self ._write_or_rollback (query )
336
+ except sqla . exc . IntegrityError as e :
337
+ raise AlreadyExistsError (
338
+ 'Suggest Op with name %s already exists.' % operation . name
339
+ ) from e
340
+ self . _connection . commit ()
341
+
342
+ return resource
343
343
344
344
def get_suggestion_operation (
345
345
self , operation_name : str
@@ -386,8 +386,9 @@ def update_suggestion_operation(
386
386
with self ._lock :
387
387
if not self ._connection .execute (eq ).fetchone ()[0 ]:
388
388
raise NotFoundError ('Suggest op %s does not exist.' % operation .name )
389
- self ._connection . execute (uq )
389
+ self ._write_or_rollback (uq )
390
390
self ._connection .commit ()
391
+
391
392
return resource
392
393
393
394
def list_suggestion_operations (
@@ -474,16 +475,16 @@ def create_early_stopping_operation(
474
475
serialized_op = operation .SerializeToString (),
475
476
)
476
477
477
- try :
478
- with self . _lock :
479
- self ._connection . execute (query )
480
- self . _connection . commit ()
481
- return resource
482
- except sqla . exc . IntegrityError as e :
483
- self . _connection . rollback ()
484
- raise AlreadyExistsError (
485
- 'Early stopping op with name %s already exists.' % operation . name
486
- ) from e
478
+ with self . _lock :
479
+ try :
480
+ self ._write_or_rollback (query )
481
+ except sqla . exc . IntegrityError as e :
482
+ raise AlreadyExistsError (
483
+ 'Early stopping op with name %s already exists.' % operation . name
484
+ ) from e
485
+ self . _connection . commit ()
486
+
487
+ return resource
487
488
488
489
def get_early_stopping_operation (
489
490
self , operation_name : str
@@ -553,8 +554,7 @@ def update_metadata(
553
554
sq = sq .where (self ._studies_table .c .study_name == study_name )
554
555
555
556
with self ._lock :
556
- study_result = self ._connection .execute (sq )
557
- row = study_result .fetchone ()
557
+ row = self ._connection .execute (sq ).fetchone ()
558
558
if not row :
559
559
raise NotFoundError ('No such study:' , s_resource .name )
560
560
original_study = study_pb2 .Study .FromString (row .serialized_study )
@@ -567,8 +567,7 @@ def update_metadata(
567
567
usq = sqla .update (self ._studies_table )
568
568
usq = usq .where (self ._studies_table .c .study_name == study_name )
569
569
usq = usq .values (serialized_study = original_study .SerializeToString ())
570
- self ._connection .execute (usq )
571
- self ._connection .commit ()
570
+ self ._write_or_rollback (usq )
572
571
573
572
# Split the trial-related metadata by Trial.
574
573
split_metadata = collections .defaultdict (list )
@@ -583,9 +582,9 @@ def update_metadata(
583
582
# Obtain original trial.
584
583
otq = sqla .select (self ._trials_table )
585
584
otq = otq .where (self ._trials_table .c .trial_name == trial_name )
586
- trial_result = self ._connection .execute (otq )
587
- row = trial_result .fetchone ()
585
+ row = self ._connection .execute (otq ).fetchone ()
588
586
if not row :
587
+ self ._connection .rollback ()
589
588
raise NotFoundError ('No such trial:' , trial_name )
590
589
original_trial = study_pb2 .Trial .FromString (row .serialized_trial )
591
590
@@ -594,5 +593,22 @@ def update_metadata(
594
593
utq = sqla .update (self ._trials_table )
595
594
utq = utq .where (self ._trials_table .c .trial_name == trial_name )
596
595
utq = utq .values (serialized_trial = original_trial .SerializeToString ())
597
- self ._connection .execute (utq )
598
- self ._connection .commit ()
596
+ self ._write_or_rollback (utq )
597
+
598
+ # Commit ALL changes if everything went well.
599
+ self ._connection .commit ()
600
+
601
+ def _write_or_rollback (self , write_query : sqla .sql .Executable ) -> None :
602
+ """Wraps connection.execute() to roll back on write query failure.
603
+
604
+ Args:
605
+ write_query: The write query to execute.
606
+
607
+ Raises:
608
+ sqla.exc.DatabaseError: Generic database error.
609
+ """
610
+ try :
611
+ self ._connection .execute (write_query )
612
+ except sqla .exc .DatabaseError as e :
613
+ self ._connection .rollback ()
614
+ raise e
0 commit comments