Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak image and text classification LIME tutorial #1518

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions tutorials/Image_and_Text_Classification_LIME.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This dataset provides an addional segmentation mask along with every image. Compared with inspecting each pixel, the segments (or \"super-pixels\") are semantically more intuitive for human to perceive. We will discuss more in section 1.3.\n",
"This dataset provides an additional segmentation mask along with every image. Compared with inspecting each pixel, the segments (or \"super-pixels\") are semantically more intuitive for humans to perceive. We will discuss more in section 1.3.\n",
"\n",
"Let's pick one example to see how the image and corresponding mask look like. Here we choose an image with more than one segments besides background so that we can compare each segment's impact on the classification."
"Let's pick one example to see how the image and corresponding mask look like. Here we choose an image with more than one segment besides the background, so that we can compare each segment's impact on the classification."
]
},
{
Expand Down Expand Up @@ -280,9 +280,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we will bring in LIME from Captum to analyze how the Resnet made above prediction based on the sample image.\n",
"In this section, we will bring in LIME from Captum to analyze how the Resnet made the above prediction based on the sample image.\n",
"\n",
"Like many other Captum algorithms, Lime also support analyzing a number of input features together as a group. This is very useful when dealing with images, where each color channel in each pixel is an input feature. Such group is also refered as \"super-pixel\". To define our desired groups over input features, all we need is to provide a feature mask.\n",
"Like many other Captum algorithms, Lime also supports analyzing a number of input features together as a group. This is very useful when dealing with images, where each color channel in each pixel is an input feature. Such a group is also refered as \"super-pixel\". To define our desired groups over input features, all we need is to provide a feature mask.\n",
"\n",
"In case of an image input, the feature mask is a 2D image of the same size, where each pixel in the mask indicates the feature group it belongs to via an integer value. Pixels of the same value define a group.\n",
"\n",
Expand Down Expand Up @@ -319,7 +319,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"It is time to configure our Lime algorithm now. Essentially, Lime trains an interpretable surrogate model to simulate the target model's predictions. So, building an appropriate interpretable model is the most critical step in Lime. Fortunately, Captum has provided many most common interpretable models to save the efforts. We will demonstrate the usages of Linear Regression and Linear Lasso. Another important factor is the similarity function. Because Lime aims to explain the local behavior of an example, it will reweight the training samples according to their similarity distances. By default, Captum's Lime uses the exponential kernel on top of the consine distance. We will change to euclidean distance instead which is more popular in vision. "
"It is time to configure our Lime algorithm. Essentially, Lime trains an interpretable surrogate model to simulate the target model's predictions. So, building an appropriate interpretable model is the most critical step in Lime. Fortunately, Captum has provided many of the most common interpretable models to save the efforts. We will demonstrate the usages of Linear Regression and Linear Lasso. Another important factor is the similarity function. Because Lime aims to explain the local behavior of an example, it will reweight the training samples according to their similarity distances. By default, Captum's Lime uses the exponential kernel on top of the cosine distance. We will change to Euclidean distance instead which is more popular in vision. "
]
},
{
Expand Down Expand Up @@ -421,11 +421,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The result looks decent: the television segment does demonstrate strongest positive correlation with the prediction, while the chairs has relatively trivial impact and the border slightly shows negative contribution.\n",
"The result looks decent: the television segment does demonstrate strongest positive correlation with the prediction, while the chairs have relatively trivial impact and the border slightly shows negative contribution.\n",
"\n",
"However, we can further improve this result. One desired characteristic of interpretability is the ease for human to comprehend. We should help reduce the noisy interference and emphisze the real influential features. In our case, all features more or less show some influences. Adding lasso regularization to the interpretable model can effectively help us filter them. Therefore, let us try Linear Lasso with a fit coefficient `alpha`. For all built-in sklearn wrapper model, you can directly pass any sklearn supported arguments.\n",
"However, we can further improve this result. One desired characteristic of interpretability is the ease for humans to comprehend. We should help reduce the noisy interference and emphisze the real influential features. In our case, all features more or less show some influences. Adding lasso regularization to the interpretable model can effectively help us filter them. Therefore, let us try Linear Lasso with a fit coefficient `alpha`. For all built-in sklearn wrapper models, you can directly pass any sklearn supported arguments.\n",
"\n",
"Moreover, since our example only has 4 segments, there are just 16 possible combinations of interpretable representations in total. So we can exhaust them instead random sampling. The `Lime` class's argument `perturb_func` allows us to pass a generator function yielding samples. We will create the generator function iterating the combinations and set the `n_samples` to its exact length."
"Moreover, since our example only has 4 segments, there are just 16 possible combinations of interpretable representations in total. So we can exhaust them instead of random sampling. The `Lime` class's argument `perturb_func` allows us to pass a generator function yielding samples. We will create the generator function iterating the combinations and set the `n_samples` to its exact length."
]
},
{
Expand Down Expand Up @@ -493,7 +493,7 @@
"source": [
"As we can see, the new attribution result removes the chairs and border with the help of Lasso.\n",
"\n",
"Another interesting question to explore is if the model also recognize the chairs in the image. To answer it, we will use the most related label `rocking_chair` from ImageNet as the target, whose label index is `765`. We can check how confident the model feels about the alternative object."
"Another interesting question to explore is if the model also recognizes the chairs in the image. To answer this, we will use the most related label `rocking_chair` from ImageNet as the target, whose label index is `765`. We can check how confident the model feels about the alternative object."
]
},
{
Expand Down Expand Up @@ -574,7 +574,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown in the heat map, our ResNet does present right belief about the chair segment. However, it gets hindered by the television segment in the foreground. This may also explain why the model feels less confident about the chairs than the television."
"As shown in the heat map, our ResNet does present the right belief about the chair segment. However, it gets hindered by the television segment in the foreground. This may also explain why the model feels less confident about the chairs than the television."
]
},
{
Expand All @@ -588,9 +588,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We have already learned how to use Captum's Lime. This section will additionally dive into the internal sampling process to give interested readers an overview of what happens underneath. The goal of the sampling process is to collect a set of training data for the surrogate model. Every data point consists of three parts: interpretable input, model predicted label, and similarity weight. We will roughly illustrate how Lime achieve each of them behind the scene.\n",
"We have already learned how to use Captum's Lime. This section will additionally dive into the internal sampling process to give interested readers an overview of what happens underneath. The goal of the sampling process is to collect a set of training data for the surrogate model. Every data point consists of three parts: interpretable input, model predicted label, and similarity weight. We will roughly illustrate how Lime achieves each of them behind the scenes.\n",
"\n",
"As we mentioned before, Lime samples data from the interpretable space. By default, Lime uses the presense or absense of the given mask groups as interpretable features. In our example, facing the above image of 4 segments, the interpretable representation is therefore a binary vector of 4 values indicating if each segment is present or absent. This is why we know there are only 16 possible interpretable representations and can exhaust them with our `iter_combinations`. Lime will keep calling its `perturb_func` to get the sample interpretable inputs. Let us simulate this step and give us a such interpretable input."
"As we mentioned before, Lime samples data from the interpretable space. By default, Lime uses the presence or absence of the given mask groups as interpretable features. In our example, facing the above image of 4 segments, the interpretable representation is therefore a binary vector of 4 values indicating if each segment is present or absent. This is why we know there are only 16 possible interpretable representations and can exhaust them with our `iter_combinations`. Lime will keep calling its `perturb_func` to get the sample interpretable inputs. Let us simulate this step and give us such an interpretable input."
]
},
{
Expand Down Expand Up @@ -622,7 +622,7 @@
"source": [
"Our input sample `[1, 1, 0, 1]` means the third segment (television) is absent while other three segments stay. \n",
"\n",
"In order to find out what the target ImageNet's prediction is for this sample, Lime needs to convert it from interpretable space back to the original example space, i.e., the image space. The transformation takes the original example input and modify it by setting the features of the absent groups to a baseline value which is `0` by default. The transformation function is called `from_interp_rep_transform` under Lime. We will run it manually here to get the pertubed image input and then visualize what it looks like."
"In order to find out what the target ImageNet's prediction is for this sample, Lime needs to convert it from interpretable space back to the original example space, i.e., the image space. The transformation takes the original example input and modifies it by setting the features of the absent groups to a baseline value which is `0` by default. The transformation function is called `from_interp_rep_transform` under Lime. We will run it manually here to get the pertubed image input and then visualize what it looks like."
]
},
{
Expand Down Expand Up @@ -666,7 +666,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown above, compared with the original image, the absent feature, i.e., the television segment, gets masked in the perturbed image, while the rest present features stay unchanged. With the perturbed image, Lime is able to find out the model's prediction. Let us still use \"television\" as our attribution target, so the label of perturbed sample is the value of the model's prediction on \"television\". Just for curiosity, we can also check how the model's prediction changes with the perturbation."
"As shown above, compared with the original image, the absent feature, i.e., the television segment, gets masked in the perturbed image, while the other present features stay unchanged. With the perturbed image, Lime is able to find out the model's prediction. Let us still use \"television\" as our attribution target, so the label of perturbed sample is the value of the model's prediction on \"television\". Just for curiosity, we can also check how the model's prediction changes with the perturbation."
]
},
{
Expand Down Expand Up @@ -706,9 +706,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Reasonably, our ImageNet no longer feel confident about classifying the image as television.\n",
"Reasonably, our ImageNet no longer feel confident about classifying the image as a television.\n",
"\n",
"At last, because Lime focuses on the local interpretability, it will calculate the similarity between the perturbed and original images to reweight the loss of this data point. Note the calculation is based on the input space instead of the interpretable space. This step is simply passing the two image tensors into the given `similarity_func` argument which is the exponential kernel of euclidean distance in our case."
"At last, because Lime focuses on the local interpretability, it will calculate the similarity between the perturbed and original images to reweight the loss of this data point. Note the calculation is based on the input space instead of the interpretable space. This step is simply passing the two image tensors into the given `similarity_func` argument which is the exponential kernel of Euclidean distance in our case."
]
},
{
Expand All @@ -735,7 +735,7 @@
"source": [
"This is basically how Lime create a single training data point of `sample_interp_inp`, `sample_label`, and `sample_similarity`. By repeating this process `n_samples` times, it collects a dataset to train the interpretable model.\n",
"\n",
"Worth noting that the steps we showed in this section is an example based on our Lime instance configured above. The logic of each step can be customized, especially with `LimeBase` class which will be demonstrated in Section 2."
"It is Worth noting that the steps we showed in this section is an example based on our Lime instance configured above. The logic of each step can be customized, especially with `LimeBase` class which will be demonstrated in Section 2."
]
},
{
Expand All @@ -749,7 +749,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we will take use of a news subject classification example to demonstrate more customizable functions in Lime. We will train a simple embedding-bag classifier on AG_NEWS dataset and analyze its understanding of words."
"In this section, we will take use of a news subject classification example to demonstrate more customizable functions in Lime. We will train a simple embedding-bag classifier on the AG_NEWS dataset and analyze its understanding of words."
]
},
{
Expand Down Expand Up @@ -994,7 +994,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, it is time to bring back Lime to inspect how the model makes the prediction. However, we will use the more customizable `LimeBase` class this time which is also the low-level implementation powering the `Lime` class we used before. The `Lime` class is opinionated when creating features from perturbed binary interpretable representations. It can only set the \"absense\" features to some baseline values while keeping other \"presense\" features. This is not what we want in this case. For text, the interpretable representation is a binary vector indicating if the word of each position is present or not. The corresponding text input should literally remove the absent words so our embedding-bag can calculate the average embeddings of the left words. Setting them to any baselines will pollute the calculation and moreover, our embedding-bag does not have common baseline tokens like `<padding>` at all. Therefore, we have to use `LimeBase` to customize the conversion logic through the `from_interp_rep_transform` argument.\n",
"Finally, it is time to bring back Lime to inspect how the model makes the prediction. However, we will use the more customizable `LimeBase` class this time which is also the low-level implementation powering the `Lime` class we used before. The `Lime` class is opinionated when creating features from perturbed binary interpretable representations. It can only set the \"absence\" features to some baseline values while keeping other \"presence\" features. This is not what we want in this case. For text, the interpretable representation is a binary vector indicating if the word of each position is present or not. The corresponding text input should literally remove the absent words so our embedding-bag can calculate the average embeddings of the left words. Setting them to any baselines will pollute the calculation and moreover, our embedding-bag does not have common baseline tokens like `<padding>` at all. Therefore, we have to use `LimeBase` to customize the conversion logic through the `from_interp_rep_transform` argument.\n",
"\n",
"`LimeBase` is not opinionated at all so we have to define every piece manually. Let us talk about them in order:\n",
"- `forward_func`, the forward function of the model. Notice we cannot pass our model directly since Captum always assumes the first dimension is batch while our embedding-bag requires flattened indices. So we will add the dummy dimension later when calling `attribute` and make a wrapper here to remove the dummy dimension before giving to our model.\n",
Expand Down Expand Up @@ -1028,7 +1028,7 @@
" probs = torch.ones_like(text) * 0.5\n",
" return torch.bernoulli(probs).long()\n",
"\n",
"# remove absenst token based on the intepretable representation sample\n",
"# remove absent token based on the intepretable representation sample\n",
"def interp_to_input(interp_sample, original_input, **kwargs):\n",
" return original_input[interp_sample.bool()].view(original_input.size(0), -1)\n",
"\n",
Expand Down