Potential issue with learning BatchNorm parameters #204

@MikiFER

Hi, in my project I have encountered an issue, and I'm not sure whether it is caused by incorrect usage of the library or by a bug in the library code. I cannot provide minimal code for reproduction because the problem occurs during training, so I will describe it as best I can.
Pseudo-code for my training looks something like:

for i in range(number_epochs):
    # train loop
    model.train()
    for input, gt in train_dataloader:
        with composite.context(model) as modified_model:
            model_out = modified_model(input)
            task_loss = get_task_loss(model_out, gt)
            # `input` must have requires_grad=True for this call
            gt_maps = torch.autograd.grad(model_out, input, torch.ones_like(model_out), retain_graph=True)[0].sum(1)
            salience_loss = get_saliance_loss(model_out, gt)
            total_loss = task_loss + salience_loss
        
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
            
    # validation loop
    model.eval()
    for input, gt in train_dataloader:
        with composite.context(model) as modified_model:
            model_out = modified_model(input)
            ...

For the model, I have tested VGG16, ResNet34 and VGG16_bn with the appropriate canonizers, and for the composite I have used EpsilonPlusFlat. All models have their heads replaced to produce 20 outputs and are randomly initialized. I have noticed that the models with BatchNorm show a significant difference between their outputs in train mode and in eval mode.
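For context, some gap between train and eval outputs is expected from BatchNorm itself, independent of Zennit: in train mode it normalizes with the statistics of the current batch, while in eval mode it uses the stored running statistics. The following is only a minimal sketch of that mechanism in plain PyTorch; the toy layer and input are hypothetical and not taken from the models above, so it does not say whether the gap reported here is larger than it should be.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy layer and input (assumptions for illustration only).
bn = nn.BatchNorm2d(4)
# Input whose statistics (mean ~5, std ~3) differ from the initial
# running statistics of the layer (mean 0, var 1).
x = 5.0 + 3.0 * torch.randn(8, 4, 16, 16)

bn.train()
# Train mode: normalized with the mean/var of this batch;
# the running statistics are updated as a side effect.
out_train = bn(x)

bn.eval()
# Eval mode: normalized with running_mean / running_var instead.
out_eval = bn(x)

train_sum = out_train.sum().item()
eval_sum = out_eval.sum().item()
print(f"train sum: {train_sum:.4f}, eval sum: {eval_sum:.4f}")
```

With randomly initialized models and only a few updates of the running statistics, the eval-mode output can therefore sit on a very different scale than the train-mode output.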

To show this, I have logged the sum of the outputs during training for the different models.
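The kind of logging described here could look roughly like the sketch below. The helper name `output_sums` and the toy model are hypothetical, and the comparison is done in plain PyTorch on a single batch without a Zennit composite, which may help to tell apart the effect of BatchNorm from the effect of the composite.

```python
import torch
from torch import nn


def output_sums(model, batch):
    """Return the summed model output in train and eval mode for one batch."""
    was_training = model.training
    sums = {}
    with torch.no_grad():
        for mode in ("train", "eval"):
            # Note: the train-mode forward pass still updates the BatchNorm
            # running statistics, even under torch.no_grad().
            model.train(mode == "train")
            sums[mode] = model(batch).sum().item()
    # Restore the mode the model was in before the call.
    model.train(was_training)
    return sums


torch.manual_seed(0)
# Hypothetical toy model with a 20-output head (assumption for illustration).
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 20),
)
sums = output_sums(toy, torch.randn(4, 3, 8, 8))
print(sums)
```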

For VGG16, we can see that the output sums have around the same order of magnitude, which is expected:
image

For ResNet34, we see a drastic change in the output sums, a difference of around 4 orders of magnitude:
image

For VGG16_bn, we again see a difference in the output sums, but the difference is "only" around 1 order of magnitude:
image

I realize that this behaviour is very strange, but it all points to something being wrong with BatchNorm.
The version of Zennit I'm using is 0.5.2.dev5.
I would really appreciate your help regarding this one.
Thanks in advance.


Labels: bug (Something isn't working), requires verification (It is unclear how the issue was caused, and it requires more details or a minimal example.)
