1+ from __future__ import annotations
2+
3+ from collections .abc import Container
4+ from collections .abc import Sequence
15import copy
26import datetime
37import enum
48import pickle
59import threading
610from typing import Any
7- from typing import Container
8- from typing import Dict
9- from typing import List
10- from typing import Optional
11- from typing import Sequence
1211import uuid
1312
1413import optuna
@@ -109,22 +108,22 @@ def __init__(self, log_storage: BaseJournalBackend) -> None:
109108 self .restore_replay_result (snapshot )
110109 self ._sync_with_backend ()
111110
112- def __getstate__ (self ) -> Dict [Any , Any ]:
111+ def __getstate__ (self ) -> dict [Any , Any ]:
113112 state = self .__dict__ .copy ()
114113 del state ["_worker_id_prefix" ]
115114 del state ["_replay_result" ]
116115 del state ["_thread_lock" ]
117116 return state
118117
119- def __setstate__ (self , state : Dict [Any , Any ]) -> None :
118+ def __setstate__ (self , state : dict [Any , Any ]) -> None :
120119 self .__dict__ .update (state )
121120 self ._worker_id_prefix = str (uuid .uuid4 ()) + "-"
122121 self ._replay_result = JournalStorageReplayResult (self ._worker_id_prefix )
123122 self ._thread_lock = threading .Lock ()
124123
125124 def restore_replay_result (self , snapshot : bytes ) -> None :
126125 try :
127- r : Optional [ JournalStorageReplayResult ] = pickle .loads (snapshot )
126+ r : JournalStorageReplayResult | None = pickle .loads (snapshot )
128127 except (pickle .UnpicklingError , KeyError ):
129128 _logger .warning ("Failed to restore `JournalStorageReplayResult`." )
130129 return
@@ -138,7 +137,7 @@ def restore_replay_result(self, snapshot: bytes) -> None:
138137 r ._last_created_trial_id_by_this_process = - 1
139138 self ._replay_result = r
140139
141- def _write_log (self , op_code : int , extra_fields : Dict [str , Any ]) -> None :
140+ def _write_log (self , op_code : int , extra_fields : dict [str , Any ]) -> None :
142141 worker_id = self ._replay_result .worker_id
143142 self ._backend .append_logs ([{"op_code" : op_code , "worker_id" : worker_id , ** extra_fields }])
144143
@@ -147,7 +146,7 @@ def _sync_with_backend(self) -> None:
147146 self ._replay_result .apply_logs (logs )
148147
149148 def create_new_study (
150- self , directions : Sequence [StudyDirection ], study_name : Optional [ str ] = None
149+ self , directions : Sequence [StudyDirection ], study_name : str | None = None
151150 ) -> int :
152151 study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str (uuid .uuid4 ())
153152
@@ -181,13 +180,13 @@ def delete_study(self, study_id: int) -> None:
181180 self ._sync_with_backend ()
182181
183182 def set_study_user_attr (self , study_id : int , key : str , value : Any ) -> None :
184- log : Dict [str , Any ] = {"study_id" : study_id , "user_attr" : {key : value }}
183+ log : dict [str , Any ] = {"study_id" : study_id , "user_attr" : {key : value }}
185184 with self ._thread_lock :
186185 self ._write_log (JournalOperation .SET_STUDY_USER_ATTR , log )
187186 self ._sync_with_backend ()
188187
189188 def set_study_system_attr (self , study_id : int , key : str , value : JSONSerializable ) -> None :
190- log : Dict [str , Any ] = {"study_id" : study_id , "system_attr" : {key : value }}
189+ log : dict [str , Any ] = {"study_id" : study_id , "system_attr" : {key : value }}
191190 with self ._thread_lock :
192191 self ._write_log (JournalOperation .SET_STUDY_SYSTEM_ATTR , log )
193192 self ._sync_with_backend ()
@@ -205,29 +204,29 @@ def get_study_name_from_id(self, study_id: int) -> str:
205204 self ._sync_with_backend ()
206205 return self ._replay_result .get_study (study_id ).study_name
207206
208- def get_study_directions (self , study_id : int ) -> List [StudyDirection ]:
207+ def get_study_directions (self , study_id : int ) -> list [StudyDirection ]:
209208 with self ._thread_lock :
210209 self ._sync_with_backend ()
211210 return self ._replay_result .get_study (study_id ).directions
212211
213- def get_study_user_attrs (self , study_id : int ) -> Dict [str , Any ]:
212+ def get_study_user_attrs (self , study_id : int ) -> dict [str , Any ]:
214213 with self ._thread_lock :
215214 self ._sync_with_backend ()
216215 return self ._replay_result .get_study (study_id ).user_attrs
217216
218- def get_study_system_attrs (self , study_id : int ) -> Dict [str , Any ]:
217+ def get_study_system_attrs (self , study_id : int ) -> dict [str , Any ]:
219218 with self ._thread_lock :
220219 self ._sync_with_backend ()
221220 return self ._replay_result .get_study (study_id ).system_attrs
222221
223- def get_all_studies (self ) -> List [FrozenStudy ]:
222+ def get_all_studies (self ) -> list [FrozenStudy ]:
224223 with self ._thread_lock :
225224 self ._sync_with_backend ()
226225 return copy .deepcopy (self ._replay_result .get_all_studies ())
227226
228227 # Basic trial manipulation
229- def create_new_trial (self , study_id : int , template_trial : Optional [ FrozenTrial ] = None ) -> int :
230- log : Dict [str , Any ] = {
228+ def create_new_trial (self , study_id : int , template_trial : FrozenTrial | None = None ) -> int :
229+ log : dict [str , Any ] = {
231230 "study_id" : study_id ,
232231 "datetime_start" : datetime .datetime .now ().isoformat (timespec = "microseconds" ),
233232 }
@@ -283,7 +282,7 @@ def set_trial_param(
283282 param_value_internal : float ,
284283 distribution : BaseDistribution ,
285284 ) -> None :
286- log : Dict [str , Any ] = {
285+ log : dict [str , Any ] = {
287286 "trial_id" : trial_id ,
288287 "param_name" : param_name ,
289288 "param_value_internal" : param_value_internal ,
@@ -306,9 +305,9 @@ def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: i
306305 return self ._replay_result ._study_id_to_trial_ids [study_id ][trial_number ]
307306
308307 def set_trial_state_values (
309- self , trial_id : int , state : TrialState , values : Optional [ Sequence [float ]] = None
308+ self , trial_id : int , state : TrialState , values : Sequence [float ] | None = None
310309 ) -> bool :
311- log : Dict [str , Any ] = {
310+ log : dict [str , Any ] = {
312311 "trial_id" : trial_id ,
313312 "state" : state ,
314313 "values" : values ,
@@ -331,7 +330,7 @@ def set_trial_state_values(
331330 def set_trial_intermediate_value (
332331 self , trial_id : int , step : int , intermediate_value : float
333332 ) -> None :
334- log : Dict [str , Any ] = {
333+ log : dict [str , Any ] = {
335334 "trial_id" : trial_id ,
336335 "step" : step ,
337336 "intermediate_value" : intermediate_value ,
@@ -342,7 +341,7 @@ def set_trial_intermediate_value(
342341 self ._sync_with_backend ()
343342
344343 def set_trial_user_attr (self , trial_id : int , key : str , value : Any ) -> None :
345- log : Dict [str , Any ] = {
344+ log : dict [str , Any ] = {
346345 "trial_id" : trial_id ,
347346 "user_attr" : {key : value },
348347 }
@@ -352,7 +351,7 @@ def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
352351 self ._sync_with_backend ()
353352
354353 def set_trial_system_attr (self , trial_id : int , key : str , value : JSONSerializable ) -> None :
355- log : Dict [str , Any ] = {
354+ log : dict [str , Any ] = {
356355 "trial_id" : trial_id ,
357356 "system_attr" : {key : value },
358357 }
@@ -370,8 +369,8 @@ def get_all_trials(
370369 self ,
371370 study_id : int ,
372371 deepcopy : bool = True ,
373- states : Optional [ Container [TrialState ]] = None ,
374- ) -> List [FrozenTrial ]:
372+ states : Container [TrialState ] | None = None ,
373+ ) -> list [FrozenTrial ]:
375374 with self ._thread_lock :
376375 self ._sync_with_backend ()
377376 frozen_trials = self ._replay_result .get_all_trials (study_id , states )
@@ -384,15 +383,15 @@ class JournalStorageReplayResult:
384383 def __init__ (self , worker_id_prefix : str ) -> None :
385384 self .log_number_read = 0
386385 self ._worker_id_prefix = worker_id_prefix
387- self ._studies : Dict [int , FrozenStudy ] = {}
388- self ._trials : Dict [int , FrozenTrial ] = {}
386+ self ._studies : dict [int , FrozenStudy ] = {}
387+ self ._trials : dict [int , FrozenTrial ] = {}
389388
390- self ._study_id_to_trial_ids : Dict [int , List [int ]] = {}
391- self ._trial_id_to_study_id : Dict [int , int ] = {}
389+ self ._study_id_to_trial_ids : dict [int , list [int ]] = {}
390+ self ._trial_id_to_study_id : dict [int , int ] = {}
392391 self ._next_study_id : int = 0
393- self ._worker_id_to_owned_trial_id : Dict [str , int ] = {}
392+ self ._worker_id_to_owned_trial_id : dict [str , int ] = {}
394393
395- def apply_logs (self , logs : List [ Dict [str , Any ]]) -> None :
394+ def apply_logs (self , logs : list [ dict [str , Any ]]) -> None :
396395 for log in logs :
397396 self .log_number_read += 1
398397 op = log ["op_code" ]
@@ -424,7 +423,7 @@ def get_study(self, study_id: int) -> FrozenStudy:
424423 raise KeyError (NOT_FOUND_MSG )
425424 return self ._studies [study_id ]
426425
427- def get_all_studies (self ) -> List [FrozenStudy ]:
426+ def get_all_studies (self ) -> list [FrozenStudy ]:
428427 return list (self ._studies .values ())
429428
430429 def get_trial (self , trial_id : int ) -> FrozenTrial :
@@ -433,12 +432,12 @@ def get_trial(self, trial_id: int) -> FrozenTrial:
433432 return self ._trials [trial_id ]
434433
435434 def get_all_trials (
436- self , study_id : int , states : Optional [ Container [TrialState ]]
437- ) -> List [FrozenTrial ]:
435+ self , study_id : int , states : Container [TrialState ] | None
436+ ) -> list [FrozenTrial ]:
438437 if study_id not in self ._studies :
439438 raise KeyError (NOT_FOUND_MSG )
440439
441- frozen_trials : List [FrozenTrial ] = []
440+ frozen_trials : list [FrozenTrial ] = []
442441 for trial_id in self ._study_id_to_trial_ids [study_id ]:
443442 trial = self ._trials [trial_id ]
444443 if states is None or trial .state in states :
@@ -450,20 +449,20 @@ def worker_id(self) -> str:
450449 return self ._worker_id_prefix + str (threading .get_ident ())
451450
452451 @property
453- def owned_trial_id (self ) -> Optional [ int ] :
452+ def owned_trial_id (self ) -> int | None :
454453 return self ._worker_id_to_owned_trial_id .get (self .worker_id )
455454
456- def _is_issued_by_this_worker (self , log : Dict [str , Any ]) -> bool :
455+ def _is_issued_by_this_worker (self , log : dict [str , Any ]) -> bool :
457456 return log ["worker_id" ] == self .worker_id
458457
459- def _study_exists (self , study_id : int , log : Dict [str , Any ]) -> bool :
458+ def _study_exists (self , study_id : int , log : dict [str , Any ]) -> bool :
460459 if study_id in self ._studies :
461460 return True
462461 if self ._is_issued_by_this_worker (log ):
463462 raise KeyError (NOT_FOUND_MSG )
464463 return False
465464
466- def _apply_create_study (self , log : Dict [str , Any ]) -> None :
465+ def _apply_create_study (self , log : dict [str , Any ]) -> None :
467466 study_name = log ["study_name" ]
468467 directions = [StudyDirection (d ) for d in log ["directions" ]]
469468
@@ -490,28 +489,28 @@ def _apply_create_study(self, log: Dict[str, Any]) -> None:
490489 )
491490 self ._study_id_to_trial_ids [study_id ] = []
492491
493- def _apply_delete_study (self , log : Dict [str , Any ]) -> None :
492+ def _apply_delete_study (self , log : dict [str , Any ]) -> None :
494493 study_id = log ["study_id" ]
495494
496495 if self ._study_exists (study_id , log ):
497496 fs = self ._studies .pop (study_id )
498497 assert fs ._study_id == study_id
499498
500- def _apply_set_study_user_attr (self , log : Dict [str , Any ]) -> None :
499+ def _apply_set_study_user_attr (self , log : dict [str , Any ]) -> None :
501500 study_id = log ["study_id" ]
502501
503502 if self ._study_exists (study_id , log ):
504503 assert len (log ["user_attr" ]) == 1
505504 self ._studies [study_id ].user_attrs .update (log ["user_attr" ])
506505
507- def _apply_set_study_system_attr (self , log : Dict [str , Any ]) -> None :
506+ def _apply_set_study_system_attr (self , log : dict [str , Any ]) -> None :
508507 study_id = log ["study_id" ]
509508
510509 if self ._study_exists (study_id , log ):
511510 assert len (log ["system_attr" ]) == 1
512511 self ._studies [study_id ].system_attrs .update (log ["system_attr" ])
513512
514- def _apply_create_trial (self , log : Dict [str , Any ]) -> None :
513+ def _apply_create_trial (self , log : dict [str , Any ]) -> None :
515514 study_id = log ["study_id" ]
516515
517516 if not self ._study_exists (study_id , log ):
@@ -556,7 +555,7 @@ def _apply_create_trial(self, log: Dict[str, Any]) -> None:
556555 if self ._trials [trial_id ].state == TrialState .RUNNING :
557556 self ._worker_id_to_owned_trial_id [self .worker_id ] = trial_id
558557
559- def _apply_set_trial_param (self , log : Dict [str , Any ]) -> None :
558+ def _apply_set_trial_param (self , log : dict [str , Any ]) -> None :
560559 trial_id = log ["trial_id" ]
561560
562561 if not self ._trial_exists_and_updatable (trial_id , log ):
@@ -589,7 +588,7 @@ def _apply_set_trial_param(self, log: Dict[str, Any]) -> None:
589588 trial .distributions = {** copy .copy (trial .distributions ), param_name : distribution }
590589 self ._trials [trial_id ] = trial
591590
592- def _apply_set_trial_state_values (self , log : Dict [str , Any ]) -> None :
591+ def _apply_set_trial_state_values (self , log : dict [str , Any ]) -> None :
593592 trial_id = log ["trial_id" ]
594593
595594 if not self ._trial_exists_and_updatable (trial_id , log ):
@@ -612,7 +611,7 @@ def _apply_set_trial_state_values(self, log: Dict[str, Any]) -> None:
612611
613612 self ._trials [trial_id ] = trial
614613
615- def _apply_set_trial_intermediate_value (self , log : Dict [str , Any ]) -> None :
614+ def _apply_set_trial_intermediate_value (self , log : dict [str , Any ]) -> None :
616615 trial_id = log ["trial_id" ]
617616
618617 if self ._trial_exists_and_updatable (trial_id , log ):
@@ -623,7 +622,7 @@ def _apply_set_trial_intermediate_value(self, log: Dict[str, Any]) -> None:
623622 }
624623 self ._trials [trial_id ] = trial
625624
626- def _apply_set_trial_user_attr (self , log : Dict [str , Any ]) -> None :
625+ def _apply_set_trial_user_attr (self , log : dict [str , Any ]) -> None :
627626 trial_id = log ["trial_id" ]
628627
629628 if self ._trial_exists_and_updatable (trial_id , log ):
@@ -632,7 +631,7 @@ def _apply_set_trial_user_attr(self, log: Dict[str, Any]) -> None:
632631 trial .user_attrs = {** copy .copy (trial .user_attrs ), ** log ["user_attr" ]}
633632 self ._trials [trial_id ] = trial
634633
635- def _apply_set_trial_system_attr (self , log : Dict [str , Any ]) -> None :
634+ def _apply_set_trial_system_attr (self , log : dict [str , Any ]) -> None :
636635 trial_id = log ["trial_id" ]
637636
638637 if self ._trial_exists_and_updatable (trial_id , log ):
@@ -644,7 +643,7 @@ def _apply_set_trial_system_attr(self, log: Dict[str, Any]) -> None:
644643 }
645644 self ._trials [trial_id ] = trial
646645
647- def _trial_exists_and_updatable (self , trial_id : int , log : Dict [str , Any ]) -> bool :
646+ def _trial_exists_and_updatable (self , trial_id : int , log : dict [str , Any ]) -> bool :
648647 if trial_id not in self ._trials :
649648 if self ._is_issued_by_this_worker (log ):
650649 raise KeyError (NOT_FOUND_MSG )
0 commit comments