@@ -197,34 +197,32 @@ def _cumulative_sum_threshold(
197197
198198
199199def _normalize_attr (
200- attr : npt .NDArray ,
200+ attr : npt .NDArray , # 2D (H, W)
201201 sign : str ,
202202 outlier_perc : Union [int , float ] = 2 ,
203- reduction_axis : Optional [int ] = None ,
204203) -> npt .NDArray :
205- attr_combined = attr
206- if reduction_axis is not None :
207- attr_combined = np .sum (attr , axis = reduction_axis )
208-
209- # Choose appropriate signed values and rescale, removing given outlier percentage.
210- if VisualizeSign [sign ].value == VisualizeSign .all .value :
211- threshold = _cumulative_sum_threshold (
212- np .abs (attr_combined ), 100.0 - outlier_perc
213- )
214- elif VisualizeSign [sign ].value == VisualizeSign .positive .value :
215- attr_combined = (attr_combined > 0 ) * attr_combined
216- threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
217- elif VisualizeSign [sign ].value == VisualizeSign .negative .value :
218- attr_combined = (attr_combined < 0 ) * attr_combined
219- threshold = - 1 * _cumulative_sum_threshold (
220- np .abs (attr_combined ), 100.0 - outlier_perc
221- )
222- elif VisualizeSign [sign ].value == VisualizeSign .absolute_value .value :
223- attr_combined = np .abs (attr_combined )
224- threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
204+ sign_type = VisualizeSign [sign ]
205+
206+ # Apply sign-specific transformation to filter/transform attribution values
207+ if sign_type == VisualizeSign .all :
208+ pass # Keep all values as-is
209+ elif sign_type == VisualizeSign .positive :
210+ attr = np .maximum (attr , 0 )
211+ elif sign_type == VisualizeSign .negative :
212+ attr = np .minimum (attr , 0 )
213+ elif sign_type == VisualizeSign .absolute_value :
214+ attr = np .abs (attr )
225215 else :
226216 raise AssertionError ("Visualize Sign type is not valid." )
227- return _normalize_scale (attr_combined , threshold )
217+
218+ # Compute threshold from absolute values, removing given outlier percentage
219+ threshold = _cumulative_sum_threshold (np .abs (attr ), 100.0 - outlier_perc )
220+
221+ # For negative sign, threshold should be negative to match the sign of values
222+ if sign_type == VisualizeSign .negative :
223+ threshold = - threshold
224+
225+ return _normalize_scale (attr , threshold )
228226
229227
230228def _create_default_plot (
@@ -371,9 +369,8 @@ def visualize_image_attr(
371369 Args:
372370
373371 attr (numpy.ndarray): Numpy array corresponding to attributions to be
374- visualized. Shape must be in the form (H, W, C), with
375- channels as last dimension. Shape must also match that of
376- the original image if provided.
372+ visualized. Shape must be in the form (H, W, C) or (H, W).
373+ Shape must also match that of the original image if provided.
377374 original_image (numpy.ndarray, optional): Numpy array corresponding to
378375 original image. Shape must be in the form (H, W, C), with
379376 channels as the last dimension. Image can be provided either
@@ -510,8 +507,12 @@ def visualize_image_attr(
510507 "alpha_scaling" : _visualize_alpha_scaling ,
511508 "original_image" : _visualize_original_image ,
512509 }
510+
511+ # if the attr contains channel, aggregate them by sum
512+ if len (attr .shape ) == 3 :
513+ attr = np .sum (attr , axis = 2 )
513514 # Choose appropriate signed attributions and normalize.
514- norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = 2 )
515+ norm_attr = _normalize_attr (attr , sign , outlier_perc )
515516
516517 # Set default colormap and bounds based on sign.
517518 default_cmap , vmin , vmax = _initialize_cmap_and_vmin_vmax (sign )
@@ -971,7 +972,7 @@ def visualize_timeseries_attr(
971972 else :
972973 plt_axis_list = plt_axis
973974
974- norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = None )
975+ norm_attr = _normalize_attr (attr , sign , outlier_perc )
975976
976977 # Set default colormap and bounds based on sign.
977978 default_cmap , vmin , vmax = _initialize_cmap_and_vmin_vmax (sign )
0 commit comments