Skip to content

Commit 2adc636

Browse files
authored
Merge pull request #41 from salesforce/fft
Fft
2 parents 6b39ea7 + d56f057 commit 2adc636

15 files changed

Lines changed: 366 additions & 54 deletions

File tree

omnixai/explainers/nlp/specific/ig.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def __init__(
186186
preprocess_function: Callable,
187187
mode: str = "classification",
188188
id2token: Dict = None,
189+
tokenizer: Callable = None,
189190
**kwargs,
190191
):
191192
"""
@@ -196,7 +197,8 @@ def __init__(
196197
into the inputs of ``model``. The first output of ``preprocess_function`` must
197198
be the token ids.
198199
:param mode: The task type, e.g., `classification` or `regression`.
199-
:param id2token: The mapping from token ids to tokens.
200+
:param id2token: The mapping from token ids to tokens. If `tokenizer` is set, `id2token` will be ignored.
201+
:param tokenizer: The tokenizer for processing text inputs, e.g., a HuggingFace tokenizer.
200202
"""
201203
super().__init__()
202204
assert preprocess_function is not None, (
@@ -207,6 +209,7 @@ def __init__(
207209
self.embedding_layer = embedding_layer
208210
self.preprocess_function = preprocess_function
209211
self.id2token = id2token
212+
self.tokenizer = tokenizer
210213

211214
ig_class = None
212215
if is_torch_available():
@@ -293,11 +296,19 @@ def explain(self, X: Text, y=None, **kwargs) -> WordImportance:
293296
steps=steps,
294297
batch_size=batch_size
295298
)
296-
tokens = inputs[0].detach().cpu().numpy() if self.model_type == "torch" else inputs[0].numpy()
299+
tokens = inputs[0].detach().cpu().numpy() if self.model_type == "torch" \
300+
else inputs[0].numpy()
301+
302+
if self.tokenizer is not None:
303+
input_tokens = [self.tokenizer.decode([t]) for t in tokens[0]]
304+
elif self.id2token is not None:
305+
input_tokens = [self.id2token[t] for t in tokens[0]]
306+
else:
307+
input_tokens = tokens[0]
297308
explanations.add(
298309
instance=instance.to_str(),
299310
target_label=y[i] if y is not None else None,
300-
tokens=tokens[0] if self.id2token is None else [self.id2token[t] for t in tokens[0]],
311+
tokens=input_tokens,
301312
importance_scores=scores,
302313
)
303314
return explanations

omnixai/explainers/vision/specific/feature_visualization/pytorch/optimizer.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from typing import Union, List
1212
from collections import defaultdict
1313
from ..utils import Objective, FeatureOptimizerMixin
14+
from ..utils import fft_inputs, fft_scale
15+
from .preprocess import fft_images
1416

1517

1618
class FeatureOptimizer(FeatureOptimizerMixin):
@@ -171,6 +173,8 @@ def optimize(
171173
value_normalizer="sigmoid",
172174
value_range=(0.05, 0.95),
173175
init_std=0.01,
176+
use_fft=False,
177+
fft_decay=1.0,
174178
normal_color=False,
175179
save_all_images=False,
176180
verbose=True,
@@ -186,16 +190,36 @@ def optimize(
186190
if not isinstance(regularizers, list):
187191
regularizers = [regularizers]
188192
regularizers = [self._regularize(reg, w) for reg, w in regularizers]
193+
if use_fft:
194+
# Using "normal color" for FFT preconditioning
195+
normal_color = True
189196

190197
device = next(self.model.parameters()).device
191-
inputs = torch.tensor(
192-
np.random.randn(*(self.num_combinations, 3, *image_shape)) * init_std,
193-
dtype=torch.float32,
194-
requires_grad=True,
195-
device=device
196-
)
198+
shape = (self.num_combinations, 3, *image_shape)
199+
if not use_fft:
200+
inputs = torch.tensor(
201+
np.random.randn(*shape) * init_std,
202+
dtype=torch.float32,
203+
requires_grad=True,
204+
device=device
205+
)
206+
normalize = lambda x: self._normalize(
207+
x, value_normalizer, value_range, normal_color)
208+
else:
209+
inputs = torch.tensor(
210+
fft_inputs(*shape, mode="torch", std=init_std),
211+
dtype=torch.float32,
212+
requires_grad=True,
213+
device=device
214+
)
215+
scales = fft_scale(
216+
image_shape[0], image_shape[1], mode="torch", decay_power=fft_decay)
217+
scales = torch.tensor(scales, dtype=torch.complex64, device=device)
218+
normalize = lambda x: self._normalize(
219+
fft_images(image_shape[0], image_shape[1], inputs, scales),
220+
value_normalizer, value_range, normal_color
221+
)
197222
optimizer = torch.optim.Adam([inputs], lr=learning_rate)
198-
normalize = lambda x: self._normalize(x, value_normalizer, value_range, normal_color)
199223

200224
results = []
201225
for i in range(num_iterations):

omnixai/explainers/vision/specific/feature_visualization/pytorch/preprocess.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77
import torch
88
import torchvision
9+
from packaging import version
910
from omnixai.preprocessing.base import TransformBase
1011

1112

@@ -121,3 +122,14 @@ def transform(self, x):
121122

122123
def invert(self, x):
123124
raise RuntimeError("`Padding` doesn't support the `invert` function.")
125+
126+
127+
def fft_images(width, height, inputs, scale):
    """
    Maps a scaled frequency-domain parameterization back to real-valued images.

    :param width: The image width. Not read here; the output size is taken from the
        last two axes of the spectrum (callers are expected to pass matching values).
    :param height: The image height. Not read here, see ``width``.
    :param inputs: A tensor whose first axis has size 2, where ``inputs[0]`` holds the
        real parts and ``inputs[1]`` the imaginary parts of the spectrum.
    :param scale: A 2D tensor of per-frequency scaling factors that is broadcast over
        the two leading axes of the spectrum.
    :return: A real tensor with the same shape as ``inputs[0]``, i.e., the inverse
        transform divided by 4.
    """
    spectrum = torch.complex(inputs[0], inputs[1]) * scale[None, None, :, :]
    # Feature detection instead of version-string parsing: in old releases ``torch.fft``
    # is a plain function without an ``irfftn`` attribute.
    irfftn = getattr(torch.fft, "irfftn", None)
    if irfftn is None:
        # Legacy API: complex numbers are encoded in a trailing axis of size 2.
        x = torch.stack([spectrum.real, spectrum.imag], dim=-1)
        image = torch.irfft(x, signal_ndim=2, normalized=False, onesided=False)
    else:
        # ``torch.fft.ifft2`` would return a complex tensor, which is not a valid image
        # and disagrees with the real output of the legacy branch. A complex-to-real
        # inverse transform on the one-sided half of the spectrum yields a real tensor
        # of the original spatial size.
        rows, cols = spectrum.shape[-2], spectrum.shape[-1]
        image = irfftn(spectrum[..., : cols // 2 + 1], s=(rows, cols), dim=(-2, -1))
    return image / 4.0

omnixai/explainers/vision/specific/feature_visualization/tf/optimizer.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import tensorflow as tf
1010
from typing import Union, List
1111
from ..utils import Objective, FeatureOptimizerMixin
12+
from ..utils import fft_inputs, fft_scale
13+
from .preprocess import fft_images
1214

1315

1416
class FeatureOptimizer(FeatureOptimizerMixin):
@@ -195,6 +197,8 @@ def optimize(
195197
value_normalizer="sigmoid",
196198
value_range=(0.05, 0.95),
197199
init_std=0.01,
200+
use_fft=False,
201+
fft_decay=1.0,
198202
normal_color=False,
199203
save_all_images=False,
200204
verbose=True,
@@ -212,10 +216,25 @@ def optimize(
212216
if not isinstance(regularizers, list):
213217
regularizers = [regularizers]
214218
regularizers = [self._regularize(reg, w) for reg, w in regularizers]
219+
if use_fft:
220+
# Using "normal color" for FFT preconditioning
221+
normal_color = True
215222

216-
inputs = tf.Variable(
217-
tf.random.normal(shape, stddev=init_std, dtype=tf.float32), trainable=True)
218-
normalize = lambda x: self._normalize(x, value_normalizer, value_range, normal_color)
223+
if not use_fft:
224+
inputs = tf.Variable(
225+
tf.random.normal(shape, stddev=init_std, dtype=tf.float32), trainable=True)
226+
normalize = lambda x: self._normalize(
227+
x, value_normalizer, value_range, normal_color)
228+
else:
229+
inputs = tf.Variable(
230+
fft_inputs(shape[0], shape[3], shape[1], shape[2], mode="tf", std=init_std),
231+
trainable=True)
232+
scales = fft_scale(shape[1], shape[2], mode="tf", decay_power=fft_decay)
233+
scales = tf.convert_to_tensor(scales, dtype=tf.complex64)
234+
normalize = lambda x: self._normalize(
235+
fft_images(shape[1], shape[2], inputs, scales),
236+
value_normalizer, value_range, normal_color
237+
)
219238
optimizer = tf.keras.optimizers.Adam(learning_rate)
220239

221240
@tf.function

omnixai/explainers/vision/specific/feature_visualization/tf/preprocess.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,11 @@ def transform(self, x):
150150

151151
def invert(self, x):
152152
raise RuntimeError("`Padding` doesn't support the `invert` function.")
153+
154+
155+
def fft_images(width, height, inputs, scale):
    """
    Maps a scaled one-sided spectrum back to images via an inverse real 2D FFT.

    :param width: The number of entries kept along axis 1 of the returned tensor.
    :param height: The number of entries kept along axis 2 of the returned tensor.
    :param inputs: A tensor whose first axis has size 2, where ``inputs[0]`` holds the
        real parts and ``inputs[1]`` the imaginary parts of the spectrum.
    :param scale: The per-frequency scaling factors multiplied into the spectrum.
    :return: The cropped inverse transform divided by 4, with the original axis 1
        moved to the last position.
    """
    real_part, imag_part = inputs[0], inputs[1]
    scaled_spectrum = tf.complex(real_part, imag_part) * scale
    # The inverse transform acts on the two innermost axes.
    spatial = tf.signal.irfft2d(scaled_spectrum)
    # Move axis 1 to the end, then crop the transform output to the requested size.
    reordered = tf.transpose(spatial, (0, 2, 3, 1))
    cropped = reordered[:, :width, :height, :]
    return cropped / 4.0

omnixai/explainers/vision/specific/feature_visualization/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,27 @@ def _process_objectives(objectives):
6969
labels.append({"type": r["type"], "layer_name": layer_name, "index": indices[i, j]})
7070
names.append(labels)
7171
return results, indices.shape[0], names
72+
73+
74+
def fft_freq(width, height, mode):
    """
    Computes the magnitude of the 2D sample frequencies of a ``width x height`` grid.

    :param width: The number of frequency samples along the first axis.
    :param height: The signal length along the second axis.
    :param mode: If "tf", only the leading ``height // 2 + 1`` (plus one more when
        ``height`` is odd) frequencies of the second axis are kept; any other value
        keeps all ``height`` frequencies.
    :return: A 2D numpy array of frequency magnitudes with ``width`` rows.
    """
    col_freq = np.fft.fftfreq(height)
    if mode == "tf":
        # Truncated second axis; odd heights keep one extra column.
        n_cols = height // 2 + 1 + int(height % 2 == 1)
        col_freq = col_freq[:n_cols]
    row_freq = np.fft.fftfreq(width)[:, None]
    return np.sqrt(col_freq ** 2 + row_freq ** 2)
83+
84+
85+
def fft_scale(width, height, mode, decay_power=1.0):
    """
    Computes per-frequency scaling factors that shrink as the frequency grows.

    :param width: The width passed to ``fft_freq``.
    :param height: The height passed to ``fft_freq``.
    :param mode: The mode passed to ``fft_freq``.
    :param decay_power: The exponent applied to the (clipped) frequency magnitudes.
    :return: A numpy array with the same shape as the output of ``fft_freq``.
    """
    magnitudes = fft_freq(width, height, mode)
    # Clip from below so that the zero frequency does not cause a division by zero.
    floor = 1.0 / max(width, height)
    clipped = np.maximum(magnitudes, floor)
    decayed = 1.0 / clipped ** decay_power
    return decayed * np.sqrt(width * height)
90+
91+
92+
def fft_inputs(batch_size, channel, width, height, mode, std=0.01):
    """
    Draws random initial values for the frequency-domain parameterization.

    :param batch_size: The size of the second axis of the result.
    :param channel: The size of the third axis of the result.
    :param width: The width passed to ``fft_freq``.
    :param height: The height passed to ``fft_freq``.
    :param mode: The mode passed to ``fft_freq``.
    :param std: The factor multiplied into the standard normal samples.
    :return: A float32 numpy array of shape ``(2, batch_size, channel, *grid)`` where
        ``grid`` is the shape of the output of ``fft_freq``.
    """
    # Only the shape of the frequency grid is needed here.
    grid_shape = fft_freq(width, height, mode).shape
    full_shape = (2, batch_size, channel) + grid_shape
    noise = np.random.randn(*full_shape) * std
    return noise.astype(np.float32)

omnixai/explainers/vision/specific/feature_visualization/visualizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def explain(
9797
transformers: Pipeline = None,
9898
regularizers: List = None,
9999
image_shape: Tuple = None,
100+
use_fft=False,
101+
fft_decay=1.0,
100102
normal_color: bool = False,
101103
verbose: bool = True,
102104
**kwargs
@@ -113,6 +115,8 @@ def explain(
113115
:param regularizers: A list of regularizers applied on images. Each regularizer is a tupe
114116
`(regularizer_type, weight)` where `regularizer_type` is "l1", "l2" or "tv".
115117
:param image_shape: The customized image shape. If None, the default shape is (224, 224).
118+
:param use_fft: Whether to use Fourier preconditioning.
119+
:param fft_decay: The value controlling the allowed energy of the high frequency.
116120
:param normal_color: Whether to map uncorrelated colors to normal colors.
117121
:param verbose: Whether to print the optimization progress.
118122
:return: The optimized images for the objectives.
@@ -155,6 +159,8 @@ def explain(
155159
value_normalizer=value_normalizer,
156160
value_range=value_range,
157161
init_std=init_std,
162+
use_fft=use_fft,
163+
fft_decay=fft_decay,
158164
normal_color=normal_color,
159165
save_all_images=False,
160166
verbose=verbose

omnixai/tests/explainers/feature_visualization/explainer_torch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _plot(x):
2929
def test_layer(self):
3030
objectives = [
3131
Objective(
32-
layer=self.model.features[20],
32+
layer=self.model.features[-6],
3333
channel_indices=list(range(5))
3434
)
3535
]
@@ -39,7 +39,8 @@ def test_layer(self):
3939
)
4040
results, names = optimizer.optimize(
4141
num_iterations=300,
42-
image_shape=(224, 224)
42+
image_shape=(224, 224),
43+
use_fft=True
4344
)
4445
for res, name in zip(results[-1], names):
4546
print(name)

omnixai/tests/explainers/feature_visualization/feature_explainer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ class TestExplainer(unittest.TestCase):
1515

1616
def setUp(self) -> None:
1717
device = "cuda" if torch.cuda.is_available() else "cpu"
18-
# self.model = models.vgg16(pretrained=True).to(device)
19-
# self.target_layer = self.model.features[20]
20-
self.model = vgg16.VGG16()
21-
self.target_layer = self.model.layers[15]
18+
self.model = models.vgg16(pretrained=True).to(device)
19+
self.target_layer = self.model.features[-6]
20+
# self.model = vgg16.VGG16()
21+
# self.target_layer = self.model.layers[15]
2222

2323
def test(self):
2424
optimizer = FeatureVisualizer(
@@ -27,7 +27,8 @@ def test(self):
2727
)
2828
explanations = optimizer.explain(
2929
num_iterations=300,
30-
image_shape=(224, 224)
30+
image_shape=(224, 224),
31+
use_fft=True
3132
)
3233
explanations.ipython_plot()
3334

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#
2+
# Copyright (c) 2022 salesforce.com, inc.
3+
# All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
import unittest
8+
import torch
9+
import tensorflow as tf
10+
from omnixai.explainers.vision.specific.feature_visualization.utils import \
11+
fft_freq, fft_scale, fft_inputs
12+
from omnixai.explainers.vision.specific.feature_visualization.tf.preprocess import \
13+
fft_images as fft_images_tf
14+
from omnixai.explainers.vision.specific.feature_visualization.pytorch.preprocess import \
15+
fft_images as fft_images_torch
16+
17+
18+
class TestFFT(unittest.TestCase):
    """Shape checks for the FFT helper functions in "torch" and "tf" modes."""

    # Dimensions shared by all test cases.
    BATCH_SIZE = 5
    CHANNEL = 3
    WIDTH = 10
    HEIGHT = 7

    def test_1(self):
        # In "torch" mode the frequency grid has the full (width, height) shape.
        w, h = self.WIDTH, self.HEIGHT
        freq = fft_freq(w, h, "torch")
        scale = fft_scale(w, h, "torch")
        inputs = fft_inputs(self.BATCH_SIZE, self.CHANNEL, w, h, "torch")
        self.assertEqual(freq.shape, (10, 7))
        self.assertEqual(scale.shape, (10, 7))
        self.assertEqual(inputs.shape, (2, 5, 3, 10, 7))

    def test_2(self):
        # In "tf" mode the last axis of the frequency grid is truncated.
        w, h = self.WIDTH, self.HEIGHT
        freq = fft_freq(w, h, "tf")
        scale = fft_scale(w, h, "tf")
        inputs = fft_inputs(self.BATCH_SIZE, self.CHANNEL, w, h, "tf")
        self.assertEqual(freq.shape, (10, 5))
        self.assertEqual(scale.shape, (10, 5))
        self.assertEqual(inputs.shape, (2, 5, 3, 10, 5))

    def test_3(self):
        # The TensorFlow helper returns images with the channel axis last.
        w, h = self.WIDTH, self.HEIGHT
        scale = tf.convert_to_tensor(fft_scale(w, h, "tf"), dtype=tf.complex64)
        inputs = tf.convert_to_tensor(
            fft_inputs(self.BATCH_SIZE, self.CHANNEL, w, h, "tf"))

        images = fft_images_tf(w, h, inputs, scale)
        self.assertEqual(images.shape, (5, 10, 7, 3))

    def test_4(self):
        # The PyTorch helper returns images with the channel axis second.
        w, h = self.WIDTH, self.HEIGHT
        scale = torch.tensor(fft_scale(w, h, "torch"), dtype=torch.complex64)
        inputs = torch.tensor(
            fft_inputs(self.BATCH_SIZE, self.CHANNEL, w, h, "torch"),
            dtype=torch.float32)

        images = fft_images_torch(w, h, inputs, scale)
        self.assertEqual(images.shape, (5, 3, 10, 7))


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)