Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 66 additions & 29 deletions ax/analysis/plotly/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from plotly import express as px, graph_objects as go
from pyre_extensions import assert_is_instance, none_throws, override

# SensitivityAnalysisPlot uses a plotly bar chart which needs especially short labels
MAX_LABEL_LEN: int = 20
# Maximum characters per line for y-axis labels before wrapping with <br>
MAX_LINE_LEN: int = 25

SENSITIVITY_CARDGROUP_TITLE = (
"Sensitivity Analysis: Understand how each parameter affects metrics"
Expand Down Expand Up @@ -148,7 +148,7 @@ def compute(

# If a human readable metric name is provided, use it
metric_label = self.labels.get(
metric_name, truncate_label(label=metric_name, n=MAX_LABEL_LEN)
metric_name, truncate_label(label=metric_name, n=MAX_LINE_LEN)
)
df, fig = _prepare_card_components(
data=data,
Expand Down Expand Up @@ -218,6 +218,45 @@ def compute_sensitivity_adhoc(
)


def _wrap_label(name: str, max_line_len: int = MAX_LINE_LEN) -> str:
"""Wrap long parameter names using <br> for multi-line y-axis labels.

For interaction effects (containing " & "), each parameter is placed on its
own line. For single parameter names that exceed max_line_len, the name is
wrapped at underscores.
"""
if " & " in name:
parts = name.split(" & ")
wrapped_parts = [_wrap_single(p, max_line_len) for p in parts]
return " &<br>".join(wrapped_parts)
return _wrap_single(name, max_line_len)


def _wrap_single(name: str, max_line_len: int = MAX_LINE_LEN) -> str:
"""Wrap a single parameter name at underscores if it exceeds max_line_len.

The underscore at the wrap point is preserved as a leading underscore on the
next line, so the full name can be reconstructed by removing ``<br>`` tags.
"""
if len(name) <= max_line_len:
return name
segments = name.split("_")
lines: list[str] = []
current_line = ""
for segment in segments:
candidate = f"{current_line}_{segment}" if current_line else segment
if len(candidate) > max_line_len and current_line:
lines.append(current_line)
current_line = segment
else:
current_line = candidate
if current_line:
lines.append(current_line)
# Re-join with "<br>_" so the underscore at each break point is preserved
# on the next line, making the label visually faithful to the original name.
return "<br>_".join(lines)


def _prepare_data(
adapter: TorchAdapter,
metric_name: str,
Expand All @@ -229,9 +268,10 @@ def _prepare_data(
metrics=[metric_name],
order=order,
exclude_map_key=exclude_map_key,
exclude_task=True,
)

return pd.DataFrame.from_records(
df = pd.DataFrame.from_records(
[
{
"metric_name": metric_name,
Expand All @@ -243,6 +283,15 @@ def _prepare_data(
]
)

# Re-normalize sensitivities so absolute values sum to 1 per metric.
for mn in df["metric_name"].unique():
mask = df["metric_name"] == mn
total = df.loc[mask, "sensitivity"].abs().sum()
if total > 0:
df.loc[mask, "sensitivity"] = df.loc[mask, "sensitivity"] / total

return df


def _prepare_card_components(
data: pd.DataFrame,
Expand All @@ -254,33 +303,21 @@ def _prepare_card_components(
["parameter_name", "sensitivity"]
].copy()

# If the parameter name is too long, truncate it.
# If the parameter name is a second order interaction, truncate each parameter name
# separately then re-combine.
# If the truncated parameter name already exists, append count at end to prevent
# collisions.
# TODO: @paschali @mgarrard clean up after implementing parameter canonical names
# Wrap long parameter names using <br> for multi-line y-axis labels.
# If the wrapped name collides with an existing one, append a count suffix.
param_names = plotting_df["parameter_name"].unique()
param_to_shortened_name = {}
shortened_name_count = {}
param_to_display_name: dict[str, str] = {}
display_name_count: dict[str, int] = {}
for name in param_names:
shortened_name = (
" & ".join(
truncate_label(label=sub_name, n=MAX_LABEL_LEN // 2)
for sub_name in name.split(" & ")
)
if "&" in name
else truncate_label(label=name, n=MAX_LABEL_LEN)
)
# track number of times each shortened name is seen
if shortened_name not in shortened_name_count:
shortened_name_count[shortened_name] = 0
display_name = _wrap_label(name)
if display_name not in display_name_count:
display_name_count[display_name] = 0
else:
shortened_name_count[shortened_name] += 1
shortened_name = shortened_name + f"_{shortened_name_count[shortened_name]}"
param_to_shortened_name[name] = shortened_name
plotting_df["truncated_parameter_name"] = plotting_df["parameter_name"].map(
param_to_shortened_name
display_name_count[display_name] += 1
display_name = display_name + f"_{display_name_count[display_name]}"
param_to_display_name[name] = display_name
plotting_df["display_parameter_name"] = plotting_df["parameter_name"].map(
param_to_display_name
)

plotting_df["importance"] = plotting_df["sensitivity"].abs()
Expand All @@ -292,7 +329,7 @@ def _prepare_card_components(
.reset_index()
.head(top_k),
x="importance",
y="truncated_parameter_name",
y="display_parameter_name",
orientation="h",
color="direction",
color_discrete_map={
Expand Down
65 changes: 65 additions & 0 deletions ax/analysis/plotly/tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ax.adapter.torch import TorchAdapter
from ax.analysis.plotly.sensitivity import (
_prepare_data,
_wrap_label,
compute_sensitivity_adhoc,
SensitivityAnalysisPlot,
)
Expand Down Expand Up @@ -222,3 +223,67 @@ def test_exclude_map_key(self) -> None:
exclude_map_key=True,
).flatten()
self.assertEqual(len(cards), 1)

@mock_botorch_optimize
def test_task_params_excluded(self) -> None:
"""Test that _prepare_data passes exclude_task=True to ax_parameter_sens."""
client = get_test_client()
adapter = Generators.BOTORCH_MODULAR(
experiment=client.experiment, data=client.experiment.lookup_data()
)

mock_results = {"bar": {"x": 0.6}}

with patch(
"ax.analysis.plotly.sensitivity.ax_parameter_sens",
return_value=mock_results,
) as mock_sens:
_prepare_data(
adapter=assert_is_instance(adapter, TorchAdapter),
metric_name="bar",
order="first",
)
mock_sens.assert_called_once_with(
adapter=assert_is_instance(adapter, TorchAdapter),
metrics=["bar"],
order="first",
exclude_map_key=True,
exclude_task=True,
)

def test_wrap_label(self) -> None:
cases = [
("short name unchanged", "x", "x"),
(
"long name wrapped at underscores",
"very_long_parameter_name_that_exceeds_limit",
None, # checked separately
),
(
"interaction effect split across lines",
"param_one & param_two",
"param_one &<br>param_two",
),
(
"no underscores returned as-is",
"a" * 40,
"a" * 40,
),
]
for desc, name, expected in cases:
with self.subTest(desc):
wrapped = _wrap_label(name)
if expected is not None:
self.assertEqual(wrapped, expected)
else:
# Long name: should contain <br> and reconstruct to original
self.assertIn("<br>", wrapped)
self.assertEqual(wrapped.replace("<br>", ""), name)

# Long interaction: each part independently wrapped
with self.subTest("long interaction effect"):
name = "very_long_parameter_name_alpha & very_long_parameter_name_beta"
wrapped = _wrap_label(name)
self.assertIn(" &<br>", wrapped)
parts = wrapped.split(" &<br>")
self.assertEqual(len(parts), 2)
21 changes: 20 additions & 1 deletion ax/utils/sensitivity/sobol_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ def ax_parameter_sens(
order: str = "first",
signed: bool = True,
exclude_map_key: bool = True,
exclude_task: bool = False,
**sobol_kwargs: Any,
) -> dict[str, dict[str, npt.NDArray]]:
"""
Expand All @@ -868,6 +869,9 @@ def ax_parameter_sens(
excluded from sensitivity analysis by fixing it at the maximum step value.
This makes the sensitivity analysis more interpretable for users who
care about the effect of parameters on final performance.
exclude_task: If True, task parameters (those with ``is_task=True``, e.g.
synthetic parameters from the TrialAsTask transform) will be excluded
from the sensitivity results.
sobol_kwargs: keyword arguments passed on to SobolSensitivityGPMean, and if
signed, GpDGSMGpMean.

Expand Down Expand Up @@ -910,6 +914,21 @@ def ax_parameter_sens(
# Remove MAP_KEY from output feature names
output_feature_names = [f for f in feature_names if f != MAP_KEY]

# Exclude task parameters (e.g. TRIAL_PARAM from TrialAsTask transform)
# by fixing them at their target values.
if exclude_task and digest.task_features:
if fixed_features is None:
fixed_features = {}
for task_idx in digest.task_features:
if task_idx < len(feature_names):
fixed_features[task_idx] = float(
digest.target_values.get(task_idx, bounds[1, task_idx])
)
task_name = feature_names[task_idx]
output_feature_names = [
f for f in output_feature_names if f != task_name
]

# for second order indices, we need to compute first order indices first
# which is what is done here. With the first order indices, we can then subtract
# appropriately using the first-order indices to extract the second-order indices.
Expand Down Expand Up @@ -943,7 +962,7 @@ def ax_parameter_sens(
indices = array_with_string_indices_to_dict(
rows=metrics, cols=output_feature_names, A=ind.cpu().numpy()
)
if order == "second":
if order == "second" and len(output_feature_names) >= 2:
second_order_values = compute_sobol_indices_from_model_list(
model_list=model_list,
bounds=bounds,
Expand Down
4 changes: 4 additions & 0 deletions ax/utils/sensitivity/tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def test_SobolGPMean_SAASBO_Ax_utils(self) -> None:
order="total",
signed=False,
exclude_map_key=False,
exclude_task=False,
**sobol_kwargs,
)
ind_deriv = compute_derivatives_from_model_list(
Expand All @@ -432,6 +433,7 @@ def test_SobolGPMean_SAASBO_Ax_utils(self) -> None:
order="total",
signed=True,
exclude_map_key=False,
exclude_task=False,
**sobol_kwargs,
)
for i, pname in enumerate(["x1", "x2"]):
Expand Down Expand Up @@ -627,6 +629,7 @@ def test_ax_parameter_sens_exclude_map_key(self) -> None:
order="first",
signed=False,
exclude_map_key=True,
exclude_task=False,
**sobol_kwargs,
)

Expand All @@ -637,6 +640,7 @@ def test_ax_parameter_sens_exclude_map_key(self) -> None:
order="first",
signed=False,
exclude_map_key=False,
exclude_task=False,
**sobol_kwargs,
)

Expand Down
Loading