-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathFitPlot.tsx
More file actions
108 lines (106 loc) · 3.19 KB
/
Copy pathFitPlot.tsx
File metadata and controls
108 lines (106 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import React from "react";
import Plot from "@/plotly/Plot";
import type { FitPlotMode } from "./types";
import { plotLayoutBase, sortXY } from "./plotUtils";
/** Colour for training-split traces. */
const TRAIN_COLOR = "#4f7cff";
/** Colour for validation-split traces. */
const VAL_COLOR = "#ff7c7c";
/** Make the plot fill its parent container. */
const PLOT_STYLE = { width: "100%", height: "100%" } as const;
/** Hide Plotly's mode bar and let the plot resize with its container. */
const PLOT_CONFIG = { displayModeBar: false, responsive: true } as const;

/** Dataset split a trace belongs to; also used verbatim in legend names. */
type Split = "train" | "val";

/**
 * Layout shared by both FitPlot views: the base from `plotLayoutBase`,
 * fixed margins, and the given axis titles merged into the base axis settings.
 *
 * Returns a fresh object on every call (Plotly may write computed values such as
 * axis ranges back into the layout it receives).
 */
function fitLayout(prefersDark: boolean, xTitle: string, yTitle: string) {
  // Computed once and shared by both axes.
  const base = plotLayoutBase(prefersDark);
  // Per-axis settings are read through a loose view of the base layout.
  // NOTE(review): drop this cast if plotLayoutBase's return type exposes xaxis/yaxis.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const baseAxes = base as any;
  return {
    ...base,
    autosize: true,
    margin: { l: 50, r: 20, t: 20, b: 50 },
    // Object-form titles: bare-string axis titles are deprecated in plotly.js
    // and no longer accepted as of v3.
    xaxis: { ...baseAxes.xaxis, title: { text: xTitle } },
    yaxis: { ...baseAxes.yaxis, title: { text: yTitle } }
  };
}

/**
 * Trace pair for one split in the `curve_1d` view: the raw (x, y) points as
 * semi-transparent markers, and the model predictions as a line.
 *
 * The predictions are paired with `xy.x` and ordered by `sortXY` so the line
 * is drawn left to right.
 */
function curveTraces(
  split: Split,
  xy: { x: number[]; y: number[] },
  yhat: number[],
  color: string
) {
  const curve = sortXY(xy.x, yhat);
  return [
    {
      x: xy.x,
      y: xy.y,
      type: "scatter" as const,
      mode: "markers" as const,
      name: `${split} data`,
      marker: { size: 6, color, opacity: 0.7 }
    },
    {
      x: curve.x,
      y: curve.y,
      type: "scatter" as const,
      mode: "lines" as const,
      name: `${split} model`,
      line: { color, width: 2 }
    }
  ];
}

/** Predicted-vs-actual marker trace for one split in the parity view. */
function parityTrace(split: Split, actual: number[], predicted: number[], color: string) {
  return {
    x: actual,
    y: predicted,
    type: "scatter" as const,
    mode: "markers" as const,
    name: split,
    marker: { size: 6, color }
  };
}

/**
 * Plots how well a fitted model matches its data.
 *
 * - `mode === "curve_1d"`: scatter of the (x, y) points with the model's
 *   predictions drawn as a line over x.
 * - any other mode: scatter of predictions (ŷ) against actual values (y).
 *
 * Validation traces are built and drawn only when `hasVal` is true.
 *
 * @param props.prefersDark Forwarded to `plotLayoutBase` (presumably selects the dark/light theme).
 * @param props.mode        Which of the two views above to render.
 * @param props.hasVal      Whether to include validation traces.
 * @param props.trainActual Actual training targets (x axis of the parity view).
 * @param props.valActual   Actual validation targets (x axis of the parity view).
 * @param props.trainYhat   Training predictions; paired with `trainXY.x` in `curve_1d` mode.
 * @param props.valYhat     Validation predictions; paired with `valXY.x` in `curve_1d` mode.
 * @param props.trainXY     Training points plotted in the `curve_1d` view.
 * @param props.valXY       Validation points plotted in the `curve_1d` view.
 */
export function FitPlot(props: {
  prefersDark: boolean;
  mode: FitPlotMode;
  hasVal: boolean;
  trainActual: number[];
  valActual: number[];
  trainYhat: number[];
  valYhat: number[];
  trainXY: { x: number[]; y: number[] };
  valXY: { x: number[]; y: number[] };
}): React.ReactElement {
  const isCurve = props.mode === "curve_1d";

  const data = isCurve
    ? [
        ...curveTraces("train", props.trainXY, props.trainYhat, TRAIN_COLOR),
        // Only sort/build validation traces when they will actually be shown.
        ...(props.hasVal ? curveTraces("val", props.valXY, props.valYhat, VAL_COLOR) : [])
      ]
    : [
        parityTrace("train", props.trainActual, props.trainYhat, TRAIN_COLOR),
        ...(props.hasVal
          ? [parityTrace("val", props.valActual, props.valYhat, VAL_COLOR)]
          : [])
      ];

  const layout = isCurve
    ? fitLayout(props.prefersDark, "x", "y")
    : fitLayout(props.prefersDark, "y (actual)", "ŷ (predicted)");

  return <Plot data={data} layout={layout} style={PLOT_STYLE} config={PLOT_CONFIG} />;
}