
Model inference outside of composite context #200

@MikiFER


Hi @chr5tphr,

I'm opening a new issue about running inference with a model.
I noticed that when I run the forward pass outside of (i.e. before entering) the composite context, which uses the appropriate model canonizer, I do not get the same attribution as when the forward pass is done inside the context. This concerns me. To properly learn the batch norm parameters, the forward pass has to happen outside of the context, because the context effectively turns each batch norm into an identity, so a forward pass inside the context would never update its parameter values. Is there something I'm misunderstanding here?
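For reference, the reason batch norm "becomes an identity" is that a canonizer such as `ResNetCanonizer` (as far as I understand it) merges each batch norm into the preceding linear or convolutional layer while the context is active. The sketch below shows the generic arithmetic of such a merge on a `Linear` + `BatchNorm1d` pair. It is an illustration under that assumption, not zennit's code. The merge is only exact in eval mode, where batch norm uses its fixed running statistics. In train mode it normalizes with per-batch statistics, so the merged layer would not reproduce its output. That is consistent with training happening outside the context.

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 3)
bn = torch.nn.BatchNorm1d(3)

# Give the batch norm non-trivial running statistics and affine parameters.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # eval mode: normalize with running statistics and do not update them

# Merge BN into the linear layer: scale = gamma / sqrt(running_var + eps)
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
merged = torch.nn.Linear(4, 3)
with torch.no_grad():
    merged.weight.copy_(linear.weight * scale[:, None])  # scale each output row
    merged.bias.copy_((linear.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(8, 4)
with torch.no_grad():
    reference = bn(linear(x))  # original linear followed by batch norm
    folded = merged(x)         # single merged layer; BN is now redundant
print(torch.allclose(reference, folded, atol=1e-6))  # expected: True
```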

Here is the code snippet I used to verify that the attributions differ. The same test on a VGG16 model gave the same result.

```python
import torch
from torchvision.models import resnet18

from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop
from torchvision.transforms import ToTensor, Normalize

import matplotlib.pyplot as plt

# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])
# load the image
image = Image.open('dornbusch-lighthouse.jpg')
# transform the PIL image and insert a batch-dimension
data = transform(image)[None]
data.requires_grad = True
# define target
target = torch.eye(1000)[[437]]


model = resnet18()

canonizers = [ResNetCanonizer()]
composite = EpsilonPlusFlat(canonizers=canonizers)

# Inference before context
model.eval()  # Put model in eval so batch-norm is frozen
model_out_before = model(data)
with composite.context(model) as modified_model:
    attribution_before, = torch.autograd.grad(model_out_before, data, target)

# Inference inside context
with composite.context(model) as modified_model:
    model_out_in = modified_model(data)
    attribution_in, = torch.autograd.grad(model_out_in, data, target)

relevance_before = attribution_before.cpu().sum(1).squeeze(0).numpy()
relevance_in = attribution_in.cpu().sum(1).squeeze(0).numpy()

plt.figure(figsize=(15, 5))
plt.subplot(1,3,1)
plt.imshow(transform_img(image))
plt.axis('off')
plt.subplot(1,3,2)
plt.imshow(relevance_before)
plt.axis('off')
plt.subplot(1,3,3)
plt.imshow(relevance_in)
plt.axis('off')
plt.show()
```
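One plausible explanation, assuming zennit applies its rules through module hooks that are registered when the context is entered and that insert the modified backward functions during the forward pass: a graph built before entering the context contains none of those backward modifications. In that case `attribution_before` would essentially be the plain gradient. The pure-PyTorch sketch below illustrates this general mechanism. It uses a hypothetical toy rule, `scale_grad_forward_hook`, that halves the gradient at a layer's output. It is not zennit's implementation.

```python
import torch


def scale_grad_forward_hook(module, inputs, output):
    # Hypothetical toy "rule": during the forward pass, attach a tensor hook to the
    # module output so that the gradient flowing back through it is halved.
    output.register_hook(lambda grad: grad * 0.5)


torch.manual_seed(0)
layer = torch.nn.Linear(3, 1)
x = torch.randn(1, 3, requires_grad=True)

# 1) Forward pass BEFORE the hook exists: the graph contains no modified backward.
out_before = layer(x)

handle = layer.register_forward_hook(scale_grad_forward_hook)
# The hook is registered now, but the graph for out_before was already built.
grad_before, = torch.autograd.grad(out_before, x)

# 2) Forward pass WHILE the hook is registered: the tensor hook becomes part of the graph.
out_inside = layer(x)
grad_inside, = torch.autograd.grad(out_inside, x)
handle.remove()

print(torch.allclose(grad_before, layer.weight))        # plain gradient, expected: True
print(torch.allclose(grad_inside, 0.5 * layer.weight))  # modified gradient, expected: True
```

If this is what happens, the two use cases can be kept apart. Training, with batch norm in train mode, can stay outside the context. For attribution, the forward pass is repeated inside the context, as in the "Inference inside context" part of the snippet.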

Metadata

Assignees: no one assigned
Labels: question (Further information is requested), testing (Improvements, additions, or issues with tests)
Projects: none
Milestone: none
Development: no branches or pull requests