Skip to content

Commit c2e0b3b

Browse files
authored
Merge pull request optuna#5709 from porink0424/fix/remove-unnecessary-distribution-compatibility-check
Reduce `SELECT` statements by removing unnecessary distribution compatibility check in `set_trial_param()`
2 parents 74e3618 + 9a04ada commit c2e0b3b

File tree

2 files changed

+6
-24
lines changed

2 files changed

+6
-24
lines changed

optuna/storages/_rdb/storage.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -588,27 +588,14 @@ def _set_trial_param_without_commit(
588588
trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
589589
self.check_trial_is_updatable(trial_id, trial.state)
590590

591-
trial_param = models.TrialParamModel.find_by_trial_and_param_name(
592-
trial, param_name, session
591+
trial_param = models.TrialParamModel(
592+
trial_id=trial_id,
593+
param_name=param_name,
594+
param_value=param_value_internal,
595+
distribution_json=distributions.distribution_to_json(distribution),
593596
)
594597

595-
if trial_param is not None:
596-
# Raise error in case distribution is incompatible.
597-
distributions.check_distribution_compatibility(
598-
distributions.json_to_distribution(trial_param.distribution_json), distribution
599-
)
600-
601-
trial_param.param_value = param_value_internal
602-
trial_param.distribution_json = distributions.distribution_to_json(distribution)
603-
else:
604-
trial_param = models.TrialParamModel(
605-
trial_id=trial_id,
606-
param_name=param_name,
607-
param_value=param_value_internal,
608-
distribution_json=distributions.distribution_to_json(distribution),
609-
)
610-
611-
trial_param.check_and_add(session)
598+
trial_param.check_and_add(session)
612599

613600
def _check_and_set_param_distribution(
614601
self,

tests/storages_tests/test_storages.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,6 @@ def test_set_trial_param(storage_mode: str) -> None:
486486
# Check set_param breaks neither get_trial nor get_trial_params.
487487
assert storage.get_trial(trial_id_1).params == {"x": 0.5, "y": "Meguro"}
488488
assert storage.get_trial_params(trial_id_1) == {"x": 0.5, "y": "Meguro"}
489-
# Duplicated registration should overwrite.
490-
storage.set_trial_param(trial_id_1, "x", 0.6, distribution_x)
491-
assert storage.get_trial_param(trial_id_1, "x") == 0.6
492-
assert storage.get_trial(trial_id_1).params == {"x": 0.6, "y": "Meguro"}
493-
assert storage.get_trial_params(trial_id_1) == {"x": 0.6, "y": "Meguro"}
494489

495490
# Set params to another trial.
496491
storage.set_trial_param(trial_id_2, "x", 0.3, distribution_x)

0 commit comments

Comments
 (0)