|
1 | | -import { Card, CardContent, Grid, Typography, useTheme } from "@mui/material" |
2 | | -import * as Optuna from "@optuna/types" |
| 1 | +import { Grid, useTheme } from "@mui/material" |
| 2 | +import { PlotTimeline } from "@optuna/react" |
3 | 3 | import * as plotly from "plotly.js-dist-min" |
4 | 4 | import React, { FC, useEffect } from "react" |
5 | | -import { StudyDetail, Trial } from "ts/types/optuna" |
| 5 | +import { StudyDetail } from "ts/types/optuna" |
6 | 6 | import { PlotType } from "../apiClient" |
7 | | -import { makeHovertext } from "../graphUtil" |
| 7 | +import { studyDetailToStudy } from "../graphUtil" |
8 | 8 | import { usePlot } from "../hooks/usePlot" |
9 | 9 | import { usePlotlyColorTheme } from "../state" |
10 | 10 | import { useBackendRender } from "../state" |
11 | 11 |
|
12 | 12 | const plotDomId = "graph-timeline" |
13 | | -const maxBars = 100 |
14 | 13 |
|
15 | 14 | export const GraphTimeline: FC<{ |
16 | 15 | study: StudyDetail | null |
17 | 16 | }> = ({ study }) => { |
| 17 | + const theme = useTheme() |
| 18 | + const colorTheme = usePlotlyColorTheme(theme.palette.mode) |
| 19 | + |
18 | 20 | if (useBackendRender()) { |
19 | 21 | return <GraphTimelineBackend study={study} /> |
20 | 22 | } else { |
21 | | - return <GraphTimelineFrontend study={study} /> |
| 23 | + return ( |
| 24 | + <PlotTimeline study={studyDetailToStudy(study)} colorTheme={colorTheme} /> |
| 25 | + ) |
22 | 26 | } |
23 | 27 | } |
24 | 28 |
|
@@ -51,176 +55,3 @@ const GraphTimelineBackend: FC<{ |
51 | 55 | </Grid> |
52 | 56 | ) |
53 | 57 | } |
54 | | - |
55 | | -const GraphTimelineFrontend: FC<{ |
56 | | - study: StudyDetail | null |
57 | | -}> = ({ study }) => { |
58 | | - const theme = useTheme() |
59 | | - const colorTheme = usePlotlyColorTheme(theme.palette.mode) |
60 | | - |
61 | | - const trials = study?.trials ?? [] |
62 | | - |
63 | | - useEffect(() => { |
64 | | - if (study !== null) { |
65 | | - plotTimeline(trials, colorTheme) |
66 | | - } |
67 | | - }, [trials, colorTheme]) |
68 | | - |
69 | | - return ( |
70 | | - <Card> |
71 | | - <CardContent> |
72 | | - <Typography |
73 | | - variant="h6" |
74 | | - sx={{ margin: "1em 0", fontWeight: theme.typography.fontWeightBold }} |
75 | | - > |
76 | | - Timeline |
77 | | - </Typography> |
78 | | - <Grid item xs={9}> |
79 | | - <div id={plotDomId} /> |
80 | | - </Grid> |
81 | | - </CardContent> |
82 | | - </Card> |
83 | | - ) |
84 | | -} |
85 | | - |
86 | | -const plotTimeline = ( |
87 | | - trials: Trial[], |
88 | | - colorTheme: Partial<Plotly.Template> |
89 | | -) => { |
90 | | - if (document.getElementById(plotDomId) === null) { |
91 | | - return |
92 | | - } |
93 | | - |
94 | | - if (trials.length === 0) { |
95 | | - plotly.react(plotDomId, [], { |
96 | | - template: colorTheme, |
97 | | - }) |
98 | | - return |
99 | | - } |
100 | | - |
101 | | - const cm: Record<Optuna.TrialState, string> = { |
102 | | - Complete: "blue", |
103 | | - Fail: "red", |
104 | | - Pruned: "orange", |
105 | | - Running: "green", |
106 | | - Waiting: "gray", |
107 | | - } |
108 | | - const runningKey = "Running" |
109 | | - |
110 | | - const lastTrials = trials.slice(-maxBars) // To only show last elements |
111 | | - const minDatetime = new Date( |
112 | | - Math.min( |
113 | | - ...lastTrials.map( |
114 | | - (t) => t.datetime_start?.getTime() ?? new Date().getTime() |
115 | | - ) |
116 | | - ) |
117 | | - ) |
118 | | - const maxRunDuration = Math.max( |
119 | | - ...trials.map((t) => { |
120 | | - return t.datetime_start === undefined || t.datetime_complete === undefined |
121 | | - ? -Infinity |
122 | | - : t.datetime_complete.getTime() - t.datetime_start.getTime() |
123 | | - }) |
124 | | - ) |
125 | | - const hasRunning = |
126 | | - (maxRunDuration === -Infinity && |
127 | | - trials.some((t) => t.state === runningKey)) || |
128 | | - trials.some((t) => { |
129 | | - if (t.state !== runningKey) { |
130 | | - return false |
131 | | - } |
132 | | - const now = new Date().getTime() |
133 | | - const start = t.datetime_start?.getTime() ?? now |
134 | | - // This is an ad-hoc handling to check if the trial is running. |
135 | | - // We do not check via `trialState` because some trials may have state=RUNNING, |
136 | | - // even if they are not running because of unexpected job kills. |
137 | | - // In this case, we would like to ensure that these trials will not squash the timeline plot |
138 | | - // for the other trials. |
139 | | - return now - start < maxRunDuration * 5 |
140 | | - }) |
141 | | - const maxDatetime = hasRunning |
142 | | - ? new Date() |
143 | | - : new Date( |
144 | | - Math.max( |
145 | | - ...lastTrials.map( |
146 | | - (t) => t.datetime_complete?.getTime() ?? minDatetime.getTime() |
147 | | - ) |
148 | | - ) |
149 | | - ) |
150 | | - const layout: Partial<plotly.Layout> = { |
151 | | - margin: { |
152 | | - l: 50, |
153 | | - t: 0, |
154 | | - r: 50, |
155 | | - b: 0, |
156 | | - }, |
157 | | - xaxis: { |
158 | | - title: "Datetime", |
159 | | - type: "date", |
160 | | - range: [minDatetime, maxDatetime], |
161 | | - }, |
162 | | - yaxis: { |
163 | | - title: "Trial", |
164 | | - range: [lastTrials[0].number, lastTrials[0].number + lastTrials.length], |
165 | | - }, |
166 | | - uirevision: "true", |
167 | | - template: colorTheme, |
168 | | - legend: { |
169 | | - x: 1.0, |
170 | | - y: 0.95, |
171 | | - }, |
172 | | - } |
173 | | - |
174 | | - const makeTrace = (bars: Trial[], state: string, color: string) => { |
175 | | - const isRunning = state === runningKey |
176 | | - // Waiting trials should not squash other trials, so use `maxDatetime` instead of `new Date()`. |
177 | | - const starts = bars.map((b) => b.datetime_start ?? maxDatetime) |
178 | | - const runDurations = bars.map((b, i) => { |
179 | | - const startTime = starts[i].getTime() |
180 | | - const completeTime = isRunning |
181 | | - ? maxDatetime.getTime() |
182 | | - : b.datetime_complete?.getTime() ?? startTime |
183 | | - // By using 1 as the min value, we can recognize these bars at least when zooming in. |
184 | | - return Math.max(1, completeTime - startTime) |
185 | | - }) |
186 | | - const trace: Partial<plotly.PlotData> = { |
187 | | - type: "bar", |
188 | | - x: runDurations, |
189 | | - y: bars.map((b) => b.number), |
190 | | - // @ts-ignore: To suppress ts(2322) |
191 | | - base: starts, |
192 | | - name: state, |
193 | | - text: bars.map((b) => makeHovertext(b)), |
194 | | - hovertemplate: "%{text}<extra>" + state + "</extra>", |
195 | | - orientation: "h", |
196 | | - marker: { color: color }, |
197 | | - textposition: "none", // Avoid drawing hovertext in a bar. |
198 | | - } |
199 | | - return trace |
200 | | - } |
201 | | - |
202 | | - const traces: Partial<plotly.PlotData>[] = [] |
203 | | - for (const [state, color] of Object.entries(cm)) { |
204 | | - const bars = trials.filter((t) => t.state === state) |
205 | | - if (bars.length === 0) { |
206 | | - continue |
207 | | - } |
208 | | - if (state === "Complete") { |
209 | | - const feasibleTrials = bars.filter((t) => |
210 | | - t.constraints.every((c) => c <= 0) |
211 | | - ) |
212 | | - const infeasibleTrials = bars.filter((t) => |
213 | | - t.constraints.some((c) => c > 0) |
214 | | - ) |
215 | | - if (feasibleTrials.length > 0) { |
216 | | - traces.push(makeTrace(feasibleTrials, "Complete", color)) |
217 | | - } |
218 | | - if (infeasibleTrials.length > 0) { |
219 | | - traces.push(makeTrace(infeasibleTrials, "Infeasible", "#cccccc")) |
220 | | - } |
221 | | - } else { |
222 | | - traces.push(makeTrace(bars, state, color)) |
223 | | - } |
224 | | - } |
225 | | - plotly.react(plotDomId, traces, layout) |
226 | | -} |
0 commit comments