Skip to content

Commit ea3271e

Browse files
authored
Add preferential optimization picture (#327)
* add prederential-optimization-picture * add README * add blank line * fix import * add CLI argument for image * delete shell file * add blankline * formatted * fix launch command * fix small
1 parent 1469719 commit ea3271e

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# How to Run Preferential Optimization Image
2+
3+
First, ensure the necessary packages are installed by executing the following command in your terminal:
4+
5+
```bash
6+
$ pip install "optuna>=3.3.0" "optuna-dashboard[preferential]>=0.13.0b1" pillow
7+
```
8+
9+
Next, execute the Python script.
10+
11+
```bash
12+
$ python generator.py --image_path sample.png
13+
```
14+
15+
Then, launch Optuna Dashboard in a separate process using the following command.
16+
17+
```bash
18+
optuna-dashboard sqlite:///db.sqlite3 --artifact-dir ./artifact
19+
```
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import os
5+
import tempfile
6+
import time
7+
from typing import NoReturn
8+
9+
from optuna.artifacts import FileSystemArtifactStore
10+
from optuna.artifacts import upload_artifact
11+
from optuna_dashboard import register_preference_feedback_component
12+
from optuna_dashboard.preferential import create_study
13+
from optuna_dashboard.preferential.samplers.gp import PreferentialGPSampler
14+
from PIL import Image
15+
from PIL import ImageEnhance
16+
17+
18+
STORAGE_URL = "sqlite:///db.sqlite3"
19+
artifact_path = os.path.join(os.path.dirname(__file__), "artifact")
20+
artifact_store = FileSystemArtifactStore(base_path=artifact_path)
21+
os.makedirs(artifact_path, exist_ok=True)
22+
23+
24+
def main() -> NoReturn:
25+
# Parse command-line arguments.
26+
parser = argparse.ArgumentParser(description="Optimize image enhancement parameters.")
27+
parser.add_argument(
28+
"--image_path", type=str, required=True, help="Path to the input image file."
29+
)
30+
args = parser.parse_args()
31+
32+
# Validate the image path.
33+
if not os.path.exists(args.image_path):
34+
raise FileNotFoundError(f"The specified image file does not exist: {args.image_path}")
35+
36+
study = create_study(
37+
n_generate=4,
38+
study_name="Preferential Optimization Image Scene",
39+
storage=STORAGE_URL,
40+
sampler=PreferentialGPSampler(),
41+
load_if_exists=True,
42+
)
43+
# Change the component, displayed on the human feedback pages.
44+
# By default (component_type="note"), the Trial's Markdown note is displayed.
45+
user_attr_key = "rgb_image"
46+
register_preference_feedback_component(study, "artifact", user_attr_key)
47+
image_sample = Image.open(args.image_path) # Use the image path from command-line arguments.
48+
with tempfile.TemporaryDirectory() as tmpdir:
49+
while True:
50+
# If study.should_generate() returns False,
51+
# the generator waits for human evaluation.
52+
if not study.should_generate():
53+
time.sleep(0.1) # Avoid busy-loop.
54+
continue
55+
56+
trial = study.ask()
57+
# 1. Ask new parameters.
58+
contrast_factor = trial.suggest_float("contrast_factor", 0.0, 2.0)
59+
brightness_factor = trial.suggest_float("brightness_factor", 0.0, 2.0)
60+
color_factor = trial.suggest_float("color_factor", 0.0, 2.0)
61+
sharpness_factor = trial.suggest_float("sharpness_factor", 0.0, 2.0)
62+
63+
# 2. Generate image.
64+
image_path = os.path.join(tmpdir, f"sample-{trial.number}.png")
65+
image = image_sample.copy()
66+
67+
image = ImageEnhance.Contrast(image).enhance(contrast_factor)
68+
image = ImageEnhance.Brightness(image).enhance(brightness_factor)
69+
image = ImageEnhance.Color(image).enhance(color_factor)
70+
image = ImageEnhance.Sharpness(image).enhance(sharpness_factor)
71+
72+
image.save(image_path)
73+
74+
# 3. Upload Artifact and set artifact_id to trial.user_attrs["rgb_image"].
75+
artifact_id = upload_artifact(
76+
artifact_store=artifact_store,
77+
file_path=image_path,
78+
study_or_trial=trial,
79+
)
80+
trial.set_user_attr(user_attr_key, artifact_id)
81+
82+
83+
if __name__ == "__main__":
84+
main()
982 KB
Loading

0 commit comments

Comments
 (0)