Skip to content

Commit c95a15f

Browse files
aobo-ymeta-codesync[bot]
authored andcommitted
Suport 2D image attr in visualize_image_attr (#1761)
Summary: Pull Request resolved: #1761 Objective: plan to unify the image attr logic between llm attr and the previous visualize_image_attr. But visualize_image_attr is mainly written aiming for gradient-based attr. So this may take a while. As the 1st step: - Revise `_normalize_attr` for readability. - Remove `reduction` from `_normalize_attr` (how to aggregate attr (mean/sum/max) is not a part of normalization) - Support `(H, W)` attr Reviewed By: craymichael Differential Revision: D90046963 fbshipit-source-id: 82dddaa591bc395ca8c511db17a19522213c2f3e
1 parent badd345 commit c95a15f

File tree

1 file changed

+29
-28
lines changed

1 file changed

+29
-28
lines changed

captum/attr/_utils/visualization.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -197,34 +197,32 @@ def _cumulative_sum_threshold(
197197

198198

199199
def _normalize_attr(
200-
attr: npt.NDArray,
200+
attr: npt.NDArray, # 2D (H, W)
201201
sign: str,
202202
outlier_perc: Union[int, float] = 2,
203-
reduction_axis: Optional[int] = None,
204203
) -> npt.NDArray:
205-
attr_combined = attr
206-
if reduction_axis is not None:
207-
attr_combined = np.sum(attr, axis=reduction_axis)
208-
209-
# Choose appropriate signed values and rescale, removing given outlier percentage.
210-
if VisualizeSign[sign].value == VisualizeSign.all.value:
211-
threshold = _cumulative_sum_threshold(
212-
np.abs(attr_combined), 100.0 - outlier_perc
213-
)
214-
elif VisualizeSign[sign].value == VisualizeSign.positive.value:
215-
attr_combined = (attr_combined > 0) * attr_combined
216-
threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc)
217-
elif VisualizeSign[sign].value == VisualizeSign.negative.value:
218-
attr_combined = (attr_combined < 0) * attr_combined
219-
threshold = -1 * _cumulative_sum_threshold(
220-
np.abs(attr_combined), 100.0 - outlier_perc
221-
)
222-
elif VisualizeSign[sign].value == VisualizeSign.absolute_value.value:
223-
attr_combined = np.abs(attr_combined)
224-
threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc)
204+
sign_type = VisualizeSign[sign]
205+
206+
# Apply sign-specific transformation to filter/transform attribution values
207+
if sign_type == VisualizeSign.all:
208+
pass # Keep all values as-is
209+
elif sign_type == VisualizeSign.positive:
210+
attr = np.maximum(attr, 0)
211+
elif sign_type == VisualizeSign.negative:
212+
attr = np.minimum(attr, 0)
213+
elif sign_type == VisualizeSign.absolute_value:
214+
attr = np.abs(attr)
225215
else:
226216
raise AssertionError("Visualize Sign type is not valid.")
227-
return _normalize_scale(attr_combined, threshold)
217+
218+
# Compute threshold from absolute values, removing given outlier percentage
219+
threshold = _cumulative_sum_threshold(np.abs(attr), 100.0 - outlier_perc)
220+
221+
# For negative sign, threshold should be negative to match the sign of values
222+
if sign_type == VisualizeSign.negative:
223+
threshold = -threshold
224+
225+
return _normalize_scale(attr, threshold)
228226

229227

230228
def _create_default_plot(
@@ -371,9 +369,8 @@ def visualize_image_attr(
371369
Args:
372370
373371
attr (numpy.ndarray): Numpy array corresponding to attributions to be
374-
visualized. Shape must be in the form (H, W, C), with
375-
channels as last dimension. Shape must also match that of
376-
the original image if provided.
372+
visualized. Shape must be in the form (H, W, C) or (H, W).
373+
Shape must also match that of the original image if provided.
377374
original_image (numpy.ndarray, optional): Numpy array corresponding to
378375
original image. Shape must be in the form (H, W, C), with
379376
channels as the last dimension. Image can be provided either
@@ -510,8 +507,12 @@ def visualize_image_attr(
510507
"alpha_scaling": _visualize_alpha_scaling,
511508
"original_image": _visualize_original_image,
512509
}
510+
511+
# if the attr contains channel, aggregate them by sum
512+
if len(attr.shape) == 3:
513+
attr = np.sum(attr, axis=2)
513514
# Choose appropriate signed attributions and normalize.
514-
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=2)
515+
norm_attr = _normalize_attr(attr, sign, outlier_perc)
515516

516517
# Set default colormap and bounds based on sign.
517518
default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign)
@@ -971,7 +972,7 @@ def visualize_timeseries_attr(
971972
else:
972973
plt_axis_list = plt_axis
973974

974-
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)
975+
norm_attr = _normalize_attr(attr, sign, outlier_perc)
975976

976977
# Set default colormap and bounds based on sign.
977978
default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign)

0 commit comments

Comments
 (0)