Skip to content

Commit 0047fca

Browse files
committed
fix: experimental classification create custom plot defaults to group by experiemnt group
Signed-off-by: Ilana Nguyen <inguyen@nvidia.com>
1 parent 1283bc4 commit 0047fca

File tree

3 files changed

+83
-80
lines changed

3 files changed

+83
-80
lines changed

src/aiperf/plot/core/plot_generator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,8 @@ def create_pareto_plot(
587587
group_data = df_sorted[df_sorted[group_by] == group]
588588
group_color = group_colors[group]
589589
# Use display name if available, otherwise use group ID
590-
group_name = display_names.get(group, group)
590+
# Convert to string to ensure compatibility with Plotly (handles numpy types)
591+
group_name = str(display_names.get(group, group))
591592

592593
# Calculate Pareto frontier for this group based on metric directions
593594
x_dir = self._get_metric_direction(x_metric)
@@ -763,7 +764,8 @@ def create_scatter_line_plot(
763764
group_data = df_sorted[df_sorted[group_by] == group]
764765
group_color = group_colors[group]
765766
# Use display name if available, otherwise use group ID
766-
group_name = display_names.get(group, group)
767+
# Convert to string to ensure compatibility with Plotly (handles numpy types)
768+
group_name = str(display_names.get(group, group))
767769

768770
# Determine shadow and main modes based on mode parameter
769771
shadow_mode = mode
@@ -864,7 +866,8 @@ def create_multi_run_bar_chart(
864866
else:
865867
group_data = df[df[group_by] == group]
866868
group_color = group_colors[group]
867-
group_name = display_names.get(group, group)
869+
# Convert to string to ensure compatibility with Plotly (handles numpy types)
870+
group_name = str(display_names.get(group, group))
868871

869872
r, g, b = mcolors.to_rgb(group_color)
870873
fillcolor = f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, 0.7)"

src/aiperf/plot/dashboard/builder.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,14 @@ def _build_custom_plot_modal(self) -> dbc.Modal:
12911291
{"label": "Model", "value": "model"},
12921292
{"label": "Concurrency", "value": "concurrency"},
12931293
]
1294+
1295+
# Add experiment group if experimental classification is enabled
1296+
exp_class_config = self.plot_config.get_experiment_classification_config()
1297+
if exp_class_config is not None:
1298+
group_by_options.append(
1299+
{"label": "Experiment Group", "value": "experiment_group"}
1300+
)
1301+
12941302
group_by_options.extend(metadata_options)
12951303

12961304
# Add swept parameters (exclude already listed options)
@@ -1299,6 +1307,7 @@ def _build_custom_plot_modal(self) -> dbc.Modal:
12991307
"endpoint_type",
13001308
"request_count",
13011309
"duration_seconds",
1310+
"experiment_group",
13021311
]:
13031312
display_name = param.replace("_", " ").replace(".", " ").title()
13041313
group_by_options.append({"label": display_name, "value": param})
@@ -1339,6 +1348,13 @@ def _build_custom_plot_modal(self) -> dbc.Modal:
13391348
),
13401349
dbc.ModalBody(
13411350
[
1351+
create_label("Plot Type", self.theme),
1352+
dcc.Dropdown(
1353+
id="custom-plot-type",
1354+
options=MULTI_RUN_PLOT_TYPES,
1355+
placeholder="Select plot type",
1356+
style={"font-size": "12px", "margin-bottom": "12px"},
1357+
),
13421358
create_label("X-Axis Metric", self.theme),
13431359
dcc.Dropdown(
13441360
id="custom-x-metric",
@@ -1385,13 +1401,6 @@ def _build_custom_plot_modal(self) -> dbc.Modal:
13851401
"margin-bottom": "16px",
13861402
},
13871403
),
1388-
create_label("Plot Type", self.theme),
1389-
dcc.Dropdown(
1390-
id="custom-plot-type",
1391-
options=MULTI_RUN_PLOT_TYPES,
1392-
placeholder="Select plot type",
1393-
style={"font-size": "12px", "margin-bottom": "12px"},
1394-
),
13951404
create_label("Label Points By", self.theme),
13961405
dcc.Dropdown(
13971406
id="custom-label-by",

0 commit comments

Comments
 (0)