A generalized implementation of Selvaraju et al.'s Grad-CAM for the Flax neural network library.
This library works with any convolutional neural network written in Flax that takes images as input.
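For reference, Selvaraju et al. define the Grad-CAM heatmap for a class c as follows. Here A^k is the k-th feature map of the observed convolutional layer, y^c is the model's score for class c before the softmax, and Z is the number of spatial positions (i, j) in that feature map. Each feature map is weighted by the spatially averaged gradient of the class score, and only the positive part of the weighted sum is kept:

```latex
\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c A^k \right)
```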
First, install it with:
pip install -U git+https://github.com/codymlewis/flax_gradcam.git
Then import fgradcam to use this library:
import fgradcam
Only three lines of code are then needed to compute the Grad-CAM heatmap and plot it. The first is added to the Flax linen module, directly after the convolutional layer that you want to analyse:
x = fgradcam.observe(self, x)
With that in place, after training the model, we compute the Grad-CAM heatmaps on the desired samples X with:
heatmaps = fgradcam.compute(model, variables, X)
Finally, a heatmap can be visualized with:
import matplotlib.pyplot as plt
fgradcam.plot(X[0], heatmaps[0])
plt.show()
A full example is given in samples/cnn.py, and samples/transfer.ipynb shows the Grad-CAM computation performed on a pretrained model.
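To illustrate what a Grad-CAM heatmap is made of, the sketch below implements only the final combination step of the technique in NumPy. It is a library-independent illustration, not fgradcam's actual implementation: it assumes the activations of the observed convolutional layer and the gradients of the class score with respect to those activations have already been obtained for a single image, both with shape (height, width, channels) (Flax's default channels-last layout). The function name `gradcam_heatmap` and the rescaling to [0, 1] are assumptions made for this example.

```python
import numpy as np


def gradcam_heatmap(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Hypothetical helper: combine feature maps into a Grad-CAM heatmap.

    Both arrays are assumed to have shape (height, width, channels) and to
    come from the same convolutional layer for a single input image.
    """
    # alpha_k: average each channel's gradient over all spatial positions.
    weights = gradients.mean(axis=(0, 1))
    # Weighted sum of the feature maps over the channel axis -> (height, width).
    cam = activations @ weights
    # ReLU: keep only regions with a positive influence on the class score.
    cam = np.maximum(cam, 0.0)
    # Rescale to [0, 1] for plotting (a common convention, assumed here).
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Usage with toy data: only channel 0 matters for the class score.
acts = np.zeros((2, 2, 3))
acts[0, 0, 0] = 2.0
acts[1, 1, 0] = 1.0
grads = np.zeros((2, 2, 3))
grads[..., 0] = 1.0
print(gradcam_heatmap(acts, grads))
```

The resulting heatmap has the spatial resolution of the observed layer, so overlaying it on the input image generally also requires resizing it to the image's height and width.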