|
| 1 | +import * as React from 'react'; |
| 2 | +import { ScatterChart } from '@mui/x-charts/ScatterChart'; |
| 3 | +import { ChartsTooltipContainer, useItemTooltip } from '@mui/x-charts/ChartsTooltip'; |
| 4 | +import Stack from '@mui/material/Stack'; |
| 5 | +import Box from '@mui/material/Box'; |
| 6 | +import Divider from '@mui/material/Divider'; |
| 7 | +import Paper from '@mui/material/Paper'; |
| 8 | +import Typography from '@mui/material/Typography'; |
| 9 | +import { useZAxis } from '@mui/x-charts/hooks'; |
| 10 | +import { irisData } from '../dataset/irisDataset'; |
| 11 | + |
| 12 | +const species = ['Setosa', 'Versicolor', 'Virginica']; |
| 13 | +const speciesColors = ['#2e7d32', '#ed6c02', '#9c27b0']; |
| 14 | +const speciesPredictionColors = ['#5aa35e', '#be7f4b', '#ba68c9']; |
| 15 | + |
| 16 | +function IrisMarker(props) { |
| 17 | + const { x, y, color, size, isHighlighted, isFaded, dataIndex, ...other } = props; |
| 18 | + |
| 19 | + // Get and use color/size scale to determine stroke color and width based on prediction and confidence z-axes. |
| 20 | + const predictionAxis = useZAxis('prediction'); |
| 21 | + const confidenceAxis = useZAxis('confidence'); |
| 22 | + const strokeWidth = |
| 23 | + confidenceAxis?.sizeScale?.(confidenceAxis.data?.[dataIndex]) ?? 0; |
| 24 | + const strokeColor = |
| 25 | + predictionAxis?.colorScale?.(predictionAxis.data?.[dataIndex]) ?? 'gray'; |
| 26 | + |
| 27 | + const r = (isHighlighted ? 1.2 : 1) * size; |
| 28 | + return ( |
| 29 | + <circle |
| 30 | + cx={0} |
| 31 | + cy={0} |
| 32 | + r={r} |
| 33 | + transform={`translate(${x}, ${y})`} |
| 34 | + fill={color} |
| 35 | + fillOpacity={isFaded ? 0.15 : 0.35} |
| 36 | + stroke={strokeColor} |
| 37 | + strokeWidth={strokeWidth} |
| 38 | + strokeOpacity={isFaded ? 0.3 : 1} |
| 39 | + cursor={other.onClick ? 'pointer' : 'unset'} |
| 40 | + {...other} |
| 41 | + /> |
| 42 | + ); |
| 43 | +} |
| 44 | + |
| 45 | +/** |
| 46 | + * Explains how the scatter plot encodes multiple dimensions of the dataset. |
| 47 | + */ |
| 48 | +function IrisAnnotation() { |
| 49 | + return ( |
| 50 | + <Typography variant="caption" sx={{ textAlign: 'center' }}> |
| 51 | + Fill color represents the actual species, stroke color represents the predicted |
| 52 | + species, bubble size represents petal length, and stroke width represents |
| 53 | + prediction confidence. |
| 54 | + </Typography> |
| 55 | + ); |
| 56 | +} |
| 57 | + |
| 58 | +function IrisTooltip() { |
| 59 | + return ( |
| 60 | + <ChartsTooltipContainer trigger="item"> |
| 61 | + <IrisTooltipContent /> |
| 62 | + </ChartsTooltipContainer> |
| 63 | + ); |
| 64 | +} |
| 65 | + |
| 66 | +function IrisTooltipContent() { |
| 67 | + const item = useItemTooltip(); |
| 68 | + |
| 69 | + function numberFormatter(value) { |
| 70 | + if (typeof value === 'number') { |
| 71 | + return new Intl.NumberFormat('en-US').format(value); |
| 72 | + } |
| 73 | + return String(value); |
| 74 | + } |
| 75 | + |
| 76 | + // Get the full data item from irisData using dataIndex from identifier |
| 77 | + const dataIndex = item?.identifier.dataIndex; |
| 78 | + const dataItem = dataIndex !== undefined ? irisData[dataIndex] : null; |
| 79 | + |
| 80 | + if (!item || !dataItem) { |
| 81 | + return null; |
| 82 | + } |
| 83 | + |
| 84 | + return ( |
| 85 | + <Paper sx={{ p: 1.5 }} elevation={4}> |
| 86 | + <Box |
| 87 | + sx={{ display: 'flex', flexDirection: 'row', alignItems: 'center', mb: 1 }} |
| 88 | + > |
| 89 | + <Box |
| 90 | + sx={{ |
| 91 | + width: 20, |
| 92 | + height: 20, |
| 93 | + backgroundColor: item?.color, |
| 94 | + borderRadius: 1, |
| 95 | + mr: 2, |
| 96 | + }} |
| 97 | + /> |
| 98 | + <Typography sx={{ fontWeight: 600 }}>{dataItem.species}</Typography> |
| 99 | + </Box> |
| 100 | + <Divider sx={{ my: 1 }} /> |
| 101 | + <Box sx={{ display: 'grid', gridTemplateColumns: '1fr', gap: 0.75 }}> |
| 102 | + <Box sx={{ display: 'flex', justifyContent: 'space-between' }}> |
| 103 | + <Typography variant="caption">Sepal Length:</Typography> |
| 104 | + <Typography variant="caption" sx={{ fontWeight: 500 }}> |
| 105 | + {numberFormatter(dataItem.sepalLength)} cm |
| 106 | + </Typography> |
| 107 | + </Box> |
| 108 | + <Box sx={{ display: 'flex', justifyContent: 'space-between' }}> |
| 109 | + <Typography variant="caption">Sepal Width:</Typography> |
| 110 | + <Typography variant="caption" sx={{ fontWeight: 500 }}> |
| 111 | + {numberFormatter(dataItem.sepalWidth)} cm |
| 112 | + </Typography> |
| 113 | + </Box> |
| 114 | + <Box sx={{ display: 'flex', justifyContent: 'space-between' }}> |
| 115 | + <Typography variant="caption">Petal Length:</Typography> |
| 116 | + <Typography variant="caption" sx={{ fontWeight: 500 }}> |
| 117 | + {numberFormatter(dataItem.petalLength)} cm |
| 118 | + </Typography> |
| 119 | + </Box> |
| 120 | + <Box sx={{ display: 'flex', justifyContent: 'space-between' }}> |
| 121 | + <Typography variant="caption">Petal Width:</Typography> |
| 122 | + <Typography variant="caption" sx={{ fontWeight: 500 }}> |
| 123 | + {numberFormatter(dataItem.petalWidth)} cm |
| 124 | + </Typography> |
| 125 | + </Box> |
| 126 | + <Divider sx={{ my: 0.5 }} /> |
| 127 | + <Box sx={{ display: 'flex', justifyContent: 'space-between' }}> |
| 128 | + <Typography variant="caption">Predicted:</Typography> |
| 129 | + <Typography variant="caption" sx={{ fontWeight: 500 }}> |
| 130 | + {dataItem.prediction} |
| 131 | + </Typography> |
| 132 | + </Box> |
| 133 | + <Box sx={{ display: 'flex', justifyContent: 'space-between' }}> |
| 134 | + <Typography variant="caption">Confidence:</Typography> |
| 135 | + <Typography variant="caption" sx={{ fontWeight: 500 }}> |
| 136 | + {numberFormatter(dataItem.confidence)}% |
| 137 | + </Typography> |
| 138 | + </Box> |
| 139 | + </Box> |
| 140 | + </Paper> |
| 141 | + ); |
| 142 | +} |
| 143 | + |
| 144 | +export default function ScatterZAxis() { |
| 145 | + return ( |
| 146 | + <Stack sx={{ width: '100%' }}> |
| 147 | + <ScatterChart |
| 148 | + dataset={irisData} |
| 149 | + height={300} |
| 150 | + grid={{ horizontal: true, vertical: true }} |
| 151 | + series={[ |
| 152 | + { |
| 153 | + id: 'data', |
| 154 | + colorAxisId: 'species', |
| 155 | + sizeAxisId: 'petal', |
| 156 | + datasetKeys: { x: 'sepalLength', y: 'sepalWidth' }, |
| 157 | + highlightScope: { highlight: 'item', fade: 'global' }, |
| 158 | + }, |
| 159 | + ]} |
| 160 | + xAxis={[{ label: 'Sepal length (cm)' }]} |
| 161 | + yAxis={[{ label: 'Sepal width (cm)' }]} |
| 162 | + zAxis={[ |
| 163 | + { |
| 164 | + id: 'species', |
| 165 | + dataKey: 'species', |
| 166 | + colorMap: { |
| 167 | + type: 'ordinal', |
| 168 | + values: species, |
| 169 | + colors: speciesColors, |
| 170 | + }, |
| 171 | + }, |
| 172 | + { |
| 173 | + id: 'petal', |
| 174 | + dataKey: 'petalLength', |
| 175 | + sizeMap: { |
| 176 | + type: 'continuous', |
| 177 | + min: 1, |
| 178 | + max: 7, |
| 179 | + size: [3, 10], |
| 180 | + }, |
| 181 | + }, |
| 182 | + { |
| 183 | + id: 'prediction', |
| 184 | + dataKey: 'prediction', |
| 185 | + colorMap: { |
| 186 | + type: 'ordinal', |
| 187 | + values: species, |
| 188 | + colors: speciesPredictionColors, |
| 189 | + }, |
| 190 | + }, |
| 191 | + { |
| 192 | + id: 'confidence', |
| 193 | + dataKey: 'confidence', |
| 194 | + sizeMap: { |
| 195 | + type: 'continuous', |
| 196 | + min: Math.min(...irisData.map((item) => item.confidence)), |
| 197 | + max: 100, |
| 198 | + size: [0.5, 3], |
| 199 | + }, |
| 200 | + }, |
| 201 | + ]} |
| 202 | + slots={{ marker: IrisMarker, tooltip: IrisTooltip }} |
| 203 | + ></ScatterChart> |
| 204 | + <IrisAnnotation /> |
| 205 | + </Stack> |
| 206 | + ); |
| 207 | +} |
0 commit comments