|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | + |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +from mcp.server.fastmcp import FastMCP |
| 7 | +import numpy as np |
| 8 | + |
| 9 | + |
| 10 | +mcp = FastMCP("AutoMatplotlib") |
| 11 | +plt.rcParams["font.family"] = "Times New Roman" |
| 12 | +plt.rcParams["font.size"] = 16 |
| 13 | +rng = np.random.RandomState(42) |
| 14 | +X1 = rng.random(size=(3, 100, 40)) * 10 - 5 |
| 15 | +X2 = np.clip(rng.normal(size=(3, 100, 40)) * 2.5, -5, 5) |
| 16 | +os.makedirs("figs/", exist_ok=True) |
| 17 | + |
| 18 | + |
| 19 | +@mcp.tool() |
| 20 | +def target_generator(trial_number: int, bbox_to_anchor_y: float) -> str: |
| 21 | + """ |
| 22 | + Generate a plot figure based on the trial suggested by Optuna MCP. |
| 23 | +
|
| 24 | + Args: |
| 25 | + trial_number: The trial number. |
| 26 | + bbox_to_anchor_y: |
| 27 | + The `bbox_to_anchor_y` stored in `params` of a `trial` suggested by Optuna MCP. |
| 28 | + """ |
| 29 | + fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(10, 5), sharex=True) |
| 30 | + dx = np.arange(100) + 1 |
| 31 | + for i, d in enumerate([5, 10, 20, 40]): |
| 32 | + ax = axes[i // 2][i % 2] |
| 33 | + |
| 34 | + def _subplot(ax: plt.Axes, values: list[list[float]]) -> plt.Line2D: |
| 35 | + cum_values = np.minimum.accumulate(values, axis=-1) |
| 36 | + mean = np.mean(cum_values, axis=0) |
| 37 | + stderr = np.std(cum_values, axis=0) / np.sqrt(len(values)) |
| 38 | + (line,) = ax.plot(dx, mean) |
| 39 | + ax.fill_between(dx, mean - stderr, mean + stderr, alpha=0.2) |
| 40 | + return line |
| 41 | + |
| 42 | + lines = [] |
| 43 | + ax.set_title(f"{d}D") |
| 44 | + lines.append(_subplot(ax, np.sum((X1[..., :d] - 2) ** 2, axis=-1))) |
| 45 | + lines.append(_subplot(ax, np.sum((X2[..., :d] - 2) ** 2, axis=-1))) |
| 46 | + |
| 47 | + fig.supxlabel("Number of Trials") |
| 48 | + fig.supylabel("Objective Values") |
| 49 | + labels = ["Uniform", "Gaussian"] |
| 50 | + loc = "lower center" |
| 51 | + bbox_to_anchor = (0.5, bbox_to_anchor_y) |
| 52 | + fig.legend(handles=lines, labels=labels, loc=loc, ncols=2, bbox_to_anchor=bbox_to_anchor) |
| 53 | + fig_path = f"figs/fig{trial_number}.png" |
| 54 | + plt.savefig(fig_path, bbox_inches="tight") |
| 55 | + return f"{fig_path} generated for Trial {trial_number} with {bbox_to_anchor_y=}" |
| 56 | + |
| 57 | + |
| 58 | +if __name__ == "__main__": |
| 59 | + mcp.run() |
0 commit comments