Skip to content

Commit faf7e85

Browse files
authored
Merge pull request optuna#6004 from fusawa-yugo/fusawa-yugo/survey-error-in-grpc-server
Fix a bug that a gRPC server doesn't work with JournalStorage
2 parents d5d3eef + a73df81 commit faf7e85

File tree

4 files changed

+29
-6
lines changed

4 files changed

+29
-6
lines changed

optuna/storages/_grpc/servicer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def SetTrialStateValues(
249249
) -> api_pb2.SetTrialStateValuesReply:
250250
trial_id = request.trial_id
251251
state = request.state
252-
values = request.values
252+
values = list(request.values)
253253
try:
254254
trial_updated = self._backend.set_trial_state_values(
255255
trial_id, _from_proto_trial_state(state), values

optuna/testing/storages.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
"cached_sqlite",
3131
"journal",
3232
"journal_redis",
33-
"grpc",
33+
"grpc_rdb",
34+
"grpc_journal_file",
3435
]
3536

3637

@@ -84,12 +85,34 @@ def __enter__(
8485
"redis", fakeredis.FakeStrictRedis() # type: ignore[no-untyped-call]
8586
)
8687
return optuna.storages.JournalStorage(journal_redis_storage)
88+
elif self.storage_specifier == "grpc_journal_file":
89+
self.tempfile = self.extra_args.get("file", NamedTemporaryFilePool().tempfile())
90+
assert self.tempfile is not None
91+
port = _find_free_port()
92+
storage = optuna.storages.JournalStorage(
93+
optuna.storages.journal.JournalFileBackend(self.tempfile.name)
94+
)
95+
96+
self.server = optuna.storages._grpc.server.make_server(storage, "localhost", port)
97+
self.thread = threading.Thread(target=self.server.start)
98+
self.thread.start()
99+
100+
proxy = GrpcStorageProxy(host="localhost", port=port)
101+
102+
# Wait until the server is ready.
103+
while True:
104+
try:
105+
proxy.get_all_studies()
106+
return proxy
107+
except grpc.RpcError:
108+
time.sleep(1)
109+
continue
87110
elif "journal" in self.storage_specifier:
88111
self.tempfile = self.extra_args.get("file", NamedTemporaryFilePool().tempfile())
89112
assert self.tempfile is not None
90113
file_storage = JournalFileBackend(self.tempfile.name)
91114
return optuna.storages.JournalStorage(file_storage)
92-
elif self.storage_specifier == "grpc":
115+
elif self.storage_specifier == "grpc_rdb":
93116
self.tempfile = NamedTemporaryFilePool().tempfile()
94117
url = "sqlite:///{}".format(self.tempfile.name)
95118
port = _find_free_port()

tests/storages_tests/test_storages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def test_get_all_trials(storage_mode: str) -> None:
773773
def test_get_all_trials_params_order(storage_mode: str, param_names: list[str]) -> None:
774774
# We don't actually require that all storages to preserve the order of parameters,
775775
# but all current implementations except for GrpcStorageProxy do, so we test this property.
776-
if storage_mode == "grpc":
776+
if storage_mode in ("grpc_rdb", "grpc_journal_file"):
777777
pytest.skip("GrpcStorageProxy does not preserve the order of parameters.")
778778

779779
with StorageSupplier(storage_mode) as storage:

tests/study_tests/test_study.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ def terminate_study(study: Study, trial: FrozenTrial) -> None:
10361036

10371037
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
10381038
def test_get_trials(storage_mode: str) -> None:
1039-
if storage_mode == "grpc":
1039+
if storage_mode in ("grpc_rdb", "grpc_journal_file"):
10401040
pytest.skip("gRPC storage doesn't use `copy.deepcopy`.")
10411041

10421042
with StorageSupplier(storage_mode) as storage:
@@ -1651,7 +1651,7 @@ def test_tell_from_another_process() -> None:
16511651

16521652
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
16531653
def test_pop_waiting_trial_thread_safe(storage_mode: str) -> None:
1654-
if "sqlite" == storage_mode or "cached_sqlite" == storage_mode or "grpc" == storage_mode:
1654+
if storage_mode in ("sqlite", "cached_sqlite", "grpc_rdb"):
16551655
pytest.skip("study._pop_waiting_trial is not thread-safe on SQLite3")
16561656

16571657
num_enqueued = 10

0 commit comments

Comments
 (0)