Skip to content

Commit a92b168

Browse files
authored
Merge pull request optuna#5841 from boringbyte/fix/_storage.py
Simplify type annotations for `optuna/storages/journal/_storage.py`
2 parents e01efa0 + 8e8ef6e commit a92b168

File tree

1 file changed

+49
-50
lines changed

1 file changed

+49
-50
lines changed

optuna/storages/journal/_storage.py

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Container
4+
from collections.abc import Sequence
15
import copy
26
import datetime
37
import enum
48
import pickle
59
import threading
610
from typing import Any
7-
from typing import Container
8-
from typing import Dict
9-
from typing import List
10-
from typing import Optional
11-
from typing import Sequence
1211
import uuid
1312

1413
import optuna
@@ -109,22 +108,22 @@ def __init__(self, log_storage: BaseJournalBackend) -> None:
109108
self.restore_replay_result(snapshot)
110109
self._sync_with_backend()
111110

112-
def __getstate__(self) -> Dict[Any, Any]:
111+
def __getstate__(self) -> dict[Any, Any]:
113112
state = self.__dict__.copy()
114113
del state["_worker_id_prefix"]
115114
del state["_replay_result"]
116115
del state["_thread_lock"]
117116
return state
118117

119-
def __setstate__(self, state: Dict[Any, Any]) -> None:
118+
def __setstate__(self, state: dict[Any, Any]) -> None:
120119
self.__dict__.update(state)
121120
self._worker_id_prefix = str(uuid.uuid4()) + "-"
122121
self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)
123122
self._thread_lock = threading.Lock()
124123

125124
def restore_replay_result(self, snapshot: bytes) -> None:
126125
try:
127-
r: Optional[JournalStorageReplayResult] = pickle.loads(snapshot)
126+
r: JournalStorageReplayResult | None = pickle.loads(snapshot)
128127
except (pickle.UnpicklingError, KeyError):
129128
_logger.warning("Failed to restore `JournalStorageReplayResult`.")
130129
return
@@ -138,7 +137,7 @@ def restore_replay_result(self, snapshot: bytes) -> None:
138137
r._last_created_trial_id_by_this_process = -1
139138
self._replay_result = r
140139

141-
def _write_log(self, op_code: int, extra_fields: Dict[str, Any]) -> None:
140+
def _write_log(self, op_code: int, extra_fields: dict[str, Any]) -> None:
142141
worker_id = self._replay_result.worker_id
143142
self._backend.append_logs([{"op_code": op_code, "worker_id": worker_id, **extra_fields}])
144143

@@ -147,7 +146,7 @@ def _sync_with_backend(self) -> None:
147146
self._replay_result.apply_logs(logs)
148147

149148
def create_new_study(
150-
self, directions: Sequence[StudyDirection], study_name: Optional[str] = None
149+
self, directions: Sequence[StudyDirection], study_name: str | None = None
151150
) -> int:
152151
study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())
153152

@@ -181,13 +180,13 @@ def delete_study(self, study_id: int) -> None:
181180
self._sync_with_backend()
182181

183182
def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
184-
log: Dict[str, Any] = {"study_id": study_id, "user_attr": {key: value}}
183+
log: dict[str, Any] = {"study_id": study_id, "user_attr": {key: value}}
185184
with self._thread_lock:
186185
self._write_log(JournalOperation.SET_STUDY_USER_ATTR, log)
187186
self._sync_with_backend()
188187

189188
def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
190-
log: Dict[str, Any] = {"study_id": study_id, "system_attr": {key: value}}
189+
log: dict[str, Any] = {"study_id": study_id, "system_attr": {key: value}}
191190
with self._thread_lock:
192191
self._write_log(JournalOperation.SET_STUDY_SYSTEM_ATTR, log)
193192
self._sync_with_backend()
@@ -205,29 +204,29 @@ def get_study_name_from_id(self, study_id: int) -> str:
205204
self._sync_with_backend()
206205
return self._replay_result.get_study(study_id).study_name
207206

208-
def get_study_directions(self, study_id: int) -> List[StudyDirection]:
207+
def get_study_directions(self, study_id: int) -> list[StudyDirection]:
209208
with self._thread_lock:
210209
self._sync_with_backend()
211210
return self._replay_result.get_study(study_id).directions
212211

213-
def get_study_user_attrs(self, study_id: int) -> Dict[str, Any]:
212+
def get_study_user_attrs(self, study_id: int) -> dict[str, Any]:
214213
with self._thread_lock:
215214
self._sync_with_backend()
216215
return self._replay_result.get_study(study_id).user_attrs
217216

218-
def get_study_system_attrs(self, study_id: int) -> Dict[str, Any]:
217+
def get_study_system_attrs(self, study_id: int) -> dict[str, Any]:
219218
with self._thread_lock:
220219
self._sync_with_backend()
221220
return self._replay_result.get_study(study_id).system_attrs
222221

223-
def get_all_studies(self) -> List[FrozenStudy]:
222+
def get_all_studies(self) -> list[FrozenStudy]:
224223
with self._thread_lock:
225224
self._sync_with_backend()
226225
return copy.deepcopy(self._replay_result.get_all_studies())
227226

228227
# Basic trial manipulation
229-
def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] = None) -> int:
230-
log: Dict[str, Any] = {
228+
def create_new_trial(self, study_id: int, template_trial: FrozenTrial | None = None) -> int:
229+
log: dict[str, Any] = {
231230
"study_id": study_id,
232231
"datetime_start": datetime.datetime.now().isoformat(timespec="microseconds"),
233232
}
@@ -283,7 +282,7 @@ def set_trial_param(
283282
param_value_internal: float,
284283
distribution: BaseDistribution,
285284
) -> None:
286-
log: Dict[str, Any] = {
285+
log: dict[str, Any] = {
287286
"trial_id": trial_id,
288287
"param_name": param_name,
289288
"param_value_internal": param_value_internal,
@@ -306,9 +305,9 @@ def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: i
306305
return self._replay_result._study_id_to_trial_ids[study_id][trial_number]
307306

308307
def set_trial_state_values(
309-
self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None
308+
self, trial_id: int, state: TrialState, values: Sequence[float] | None = None
310309
) -> bool:
311-
log: Dict[str, Any] = {
310+
log: dict[str, Any] = {
312311
"trial_id": trial_id,
313312
"state": state,
314313
"values": values,
@@ -331,7 +330,7 @@ def set_trial_state_values(
331330
def set_trial_intermediate_value(
332331
self, trial_id: int, step: int, intermediate_value: float
333332
) -> None:
334-
log: Dict[str, Any] = {
333+
log: dict[str, Any] = {
335334
"trial_id": trial_id,
336335
"step": step,
337336
"intermediate_value": intermediate_value,
@@ -342,7 +341,7 @@ def set_trial_intermediate_value(
342341
self._sync_with_backend()
343342

344343
def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
345-
log: Dict[str, Any] = {
344+
log: dict[str, Any] = {
346345
"trial_id": trial_id,
347346
"user_attr": {key: value},
348347
}
@@ -352,7 +351,7 @@ def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
352351
self._sync_with_backend()
353352

354353
def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
355-
log: Dict[str, Any] = {
354+
log: dict[str, Any] = {
356355
"trial_id": trial_id,
357356
"system_attr": {key: value},
358357
}
@@ -370,8 +369,8 @@ def get_all_trials(
370369
self,
371370
study_id: int,
372371
deepcopy: bool = True,
373-
states: Optional[Container[TrialState]] = None,
374-
) -> List[FrozenTrial]:
372+
states: Container[TrialState] | None = None,
373+
) -> list[FrozenTrial]:
375374
with self._thread_lock:
376375
self._sync_with_backend()
377376
frozen_trials = self._replay_result.get_all_trials(study_id, states)
@@ -384,15 +383,15 @@ class JournalStorageReplayResult:
384383
def __init__(self, worker_id_prefix: str) -> None:
385384
self.log_number_read = 0
386385
self._worker_id_prefix = worker_id_prefix
387-
self._studies: Dict[int, FrozenStudy] = {}
388-
self._trials: Dict[int, FrozenTrial] = {}
386+
self._studies: dict[int, FrozenStudy] = {}
387+
self._trials: dict[int, FrozenTrial] = {}
389388

390-
self._study_id_to_trial_ids: Dict[int, List[int]] = {}
391-
self._trial_id_to_study_id: Dict[int, int] = {}
389+
self._study_id_to_trial_ids: dict[int, list[int]] = {}
390+
self._trial_id_to_study_id: dict[int, int] = {}
392391
self._next_study_id: int = 0
393-
self._worker_id_to_owned_trial_id: Dict[str, int] = {}
392+
self._worker_id_to_owned_trial_id: dict[str, int] = {}
394393

395-
def apply_logs(self, logs: List[Dict[str, Any]]) -> None:
394+
def apply_logs(self, logs: list[dict[str, Any]]) -> None:
396395
for log in logs:
397396
self.log_number_read += 1
398397
op = log["op_code"]
@@ -424,7 +423,7 @@ def get_study(self, study_id: int) -> FrozenStudy:
424423
raise KeyError(NOT_FOUND_MSG)
425424
return self._studies[study_id]
426425

427-
def get_all_studies(self) -> List[FrozenStudy]:
426+
def get_all_studies(self) -> list[FrozenStudy]:
428427
return list(self._studies.values())
429428

430429
def get_trial(self, trial_id: int) -> FrozenTrial:
@@ -433,12 +432,12 @@ def get_trial(self, trial_id: int) -> FrozenTrial:
433432
return self._trials[trial_id]
434433

435434
def get_all_trials(
436-
self, study_id: int, states: Optional[Container[TrialState]]
437-
) -> List[FrozenTrial]:
435+
self, study_id: int, states: Container[TrialState] | None
436+
) -> list[FrozenTrial]:
438437
if study_id not in self._studies:
439438
raise KeyError(NOT_FOUND_MSG)
440439

441-
frozen_trials: List[FrozenTrial] = []
440+
frozen_trials: list[FrozenTrial] = []
442441
for trial_id in self._study_id_to_trial_ids[study_id]:
443442
trial = self._trials[trial_id]
444443
if states is None or trial.state in states:
@@ -450,20 +449,20 @@ def worker_id(self) -> str:
450449
return self._worker_id_prefix + str(threading.get_ident())
451450

452451
@property
453-
def owned_trial_id(self) -> Optional[int]:
452+
def owned_trial_id(self) -> int | None:
454453
return self._worker_id_to_owned_trial_id.get(self.worker_id)
455454

456-
def _is_issued_by_this_worker(self, log: Dict[str, Any]) -> bool:
455+
def _is_issued_by_this_worker(self, log: dict[str, Any]) -> bool:
457456
return log["worker_id"] == self.worker_id
458457

459-
def _study_exists(self, study_id: int, log: Dict[str, Any]) -> bool:
458+
def _study_exists(self, study_id: int, log: dict[str, Any]) -> bool:
460459
if study_id in self._studies:
461460
return True
462461
if self._is_issued_by_this_worker(log):
463462
raise KeyError(NOT_FOUND_MSG)
464463
return False
465464

466-
def _apply_create_study(self, log: Dict[str, Any]) -> None:
465+
def _apply_create_study(self, log: dict[str, Any]) -> None:
467466
study_name = log["study_name"]
468467
directions = [StudyDirection(d) for d in log["directions"]]
469468

@@ -490,28 +489,28 @@ def _apply_create_study(self, log: Dict[str, Any]) -> None:
490489
)
491490
self._study_id_to_trial_ids[study_id] = []
492491

493-
def _apply_delete_study(self, log: Dict[str, Any]) -> None:
492+
def _apply_delete_study(self, log: dict[str, Any]) -> None:
494493
study_id = log["study_id"]
495494

496495
if self._study_exists(study_id, log):
497496
fs = self._studies.pop(study_id)
498497
assert fs._study_id == study_id
499498

500-
def _apply_set_study_user_attr(self, log: Dict[str, Any]) -> None:
499+
def _apply_set_study_user_attr(self, log: dict[str, Any]) -> None:
501500
study_id = log["study_id"]
502501

503502
if self._study_exists(study_id, log):
504503
assert len(log["user_attr"]) == 1
505504
self._studies[study_id].user_attrs.update(log["user_attr"])
506505

507-
def _apply_set_study_system_attr(self, log: Dict[str, Any]) -> None:
506+
def _apply_set_study_system_attr(self, log: dict[str, Any]) -> None:
508507
study_id = log["study_id"]
509508

510509
if self._study_exists(study_id, log):
511510
assert len(log["system_attr"]) == 1
512511
self._studies[study_id].system_attrs.update(log["system_attr"])
513512

514-
def _apply_create_trial(self, log: Dict[str, Any]) -> None:
513+
def _apply_create_trial(self, log: dict[str, Any]) -> None:
515514
study_id = log["study_id"]
516515

517516
if not self._study_exists(study_id, log):
@@ -556,7 +555,7 @@ def _apply_create_trial(self, log: Dict[str, Any]) -> None:
556555
if self._trials[trial_id].state == TrialState.RUNNING:
557556
self._worker_id_to_owned_trial_id[self.worker_id] = trial_id
558557

559-
def _apply_set_trial_param(self, log: Dict[str, Any]) -> None:
558+
def _apply_set_trial_param(self, log: dict[str, Any]) -> None:
560559
trial_id = log["trial_id"]
561560

562561
if not self._trial_exists_and_updatable(trial_id, log):
@@ -589,7 +588,7 @@ def _apply_set_trial_param(self, log: Dict[str, Any]) -> None:
589588
trial.distributions = {**copy.copy(trial.distributions), param_name: distribution}
590589
self._trials[trial_id] = trial
591590

592-
def _apply_set_trial_state_values(self, log: Dict[str, Any]) -> None:
591+
def _apply_set_trial_state_values(self, log: dict[str, Any]) -> None:
593592
trial_id = log["trial_id"]
594593

595594
if not self._trial_exists_and_updatable(trial_id, log):
@@ -612,7 +611,7 @@ def _apply_set_trial_state_values(self, log: Dict[str, Any]) -> None:
612611

613612
self._trials[trial_id] = trial
614613

615-
def _apply_set_trial_intermediate_value(self, log: Dict[str, Any]) -> None:
614+
def _apply_set_trial_intermediate_value(self, log: dict[str, Any]) -> None:
616615
trial_id = log["trial_id"]
617616

618617
if self._trial_exists_and_updatable(trial_id, log):
@@ -623,7 +622,7 @@ def _apply_set_trial_intermediate_value(self, log: Dict[str, Any]) -> None:
623622
}
624623
self._trials[trial_id] = trial
625624

626-
def _apply_set_trial_user_attr(self, log: Dict[str, Any]) -> None:
625+
def _apply_set_trial_user_attr(self, log: dict[str, Any]) -> None:
627626
trial_id = log["trial_id"]
628627

629628
if self._trial_exists_and_updatable(trial_id, log):
@@ -632,7 +631,7 @@ def _apply_set_trial_user_attr(self, log: Dict[str, Any]) -> None:
632631
trial.user_attrs = {**copy.copy(trial.user_attrs), **log["user_attr"]}
633632
self._trials[trial_id] = trial
634633

635-
def _apply_set_trial_system_attr(self, log: Dict[str, Any]) -> None:
634+
def _apply_set_trial_system_attr(self, log: dict[str, Any]) -> None:
636635
trial_id = log["trial_id"]
637636

638637
if self._trial_exists_and_updatable(trial_id, log):
@@ -644,7 +643,7 @@ def _apply_set_trial_system_attr(self, log: Dict[str, Any]) -> None:
644643
}
645644
self._trials[trial_id] = trial
646645

647-
def _trial_exists_and_updatable(self, trial_id: int, log: Dict[str, Any]) -> bool:
646+
def _trial_exists_and_updatable(self, trial_id: int, log: dict[str, Any]) -> bool:
648647
if trial_id not in self._trials:
649648
if self._is_issued_by_this_worker(log):
650649
raise KeyError(NOT_FOUND_MSG)

0 commit comments

Comments
 (0)