[bug report] SRL training mechanism is wrong: validation set is used to update weights during training. #48
Description
Describe the bug
I hope I'm wrong, but it seems that the current training mechanism uses the validation set to update the weights during training, which is unwanted. The problem comes from the fact that we compute the gradient with loss.backward() in both training and validation.
The current training mechanism can be reduced to the following pseudocode (in srl-zoo/models/learner.py):
for sample in dataloader:
    if sample in train_set:  ## training mode
        model.train()
    else:  ## validation mode
        model.eval()
    optimizer.zero_grad()
    Y_pred = model(X)
    loss = compute_loss(Y_pred, Y)
    loss.backward()  # <-- [Wrong] We backpropagate the gradient in both train/valid mode
    if sample in train_set:
        optimizer.step()
        loss = loss.item()
    else:
        # We don't update the weights at this iteration, but the gradients of loss on validation
        # samples are calculated and stored and will be used the next time we call optimizer.step()
        loss = loss.item()

The common way to validate a model in PyTorch should look like the following:
for epoch in range(epochs):
    model.train()
    for sample in dataloader_train:
        optimizer.zero_grad()
        loss = compute_loss(...)
        loss.backward()
        optimizer.step()  ## update the weights (training only)
        loss = loss.item()  ## release tensor
    model.eval()
    with torch.no_grad():  ## Call both model.eval() (layer behavior) and torch.no_grad() (no gradient tracking)
        for sample in dataloader_valid:
            loss = compute_loss(...)
            loss = loss.item()  ## release tensor

But in the toolbox, torch.no_grad() is not used during validation! Besides, calling loss.backward() in validation mode has other downsides: not only is it wrong, it also wastes time on backpropagation when we don't need the gradient.
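To illustrate the difference, here is a minimal sketch (the toy nn.Linear model and the tensors are assumptions made for the example, not toolbox code). It shows that a loss computed under torch.no_grad() carries no autograd graph, so there is nothing to backpropagate and no parameter gradient is produced:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # hypothetical toy model
loss_fn = nn.MSELoss()
x = torch.rand(4, 10)
y = torch.ones(4, 1)

# Training-style forward pass: autograd records the graph.
model.train()
train_loss = loss_fn(model(x), y)
print(train_loss.requires_grad)

# Validation-style forward pass: no graph is recorded.
model.eval()
with torch.no_grad():
    val_loss = loss_fn(model(x), y)
print(val_loss.requires_grad, val_loss.grad_fn)

# backward() was never called, so the parameters hold no gradient.
print(model.weight.grad)
```

Calling val_loss.backward() on such a tensor raises an error, which makes an accidental backward pass in validation easy to spot.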
Code example
I wrote a script that mimics the training of srl-zoo and demonstrates that the current training mechanism uses the gradient of the loss on the validation data to update the model weights.
import numpy as np
from time import time
import torch
try:
    from torchsummary import summary  ## pip install torchsummary
except:
    pass
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # self.layer = nn.Linear(10, 1)
        self.model = nn.Sequential(
            nn.Linear(10, 1)
        )

    def forward(self, x):
        # x = self.layer(x)
        x = self.model(x)
        return x


if __name__=="__main__":
    print("Start")
    model = MyModel()
    try:
        summary(model, (10,))
    except:
        pass
    CASE = 0
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    X = np.random.rand(100, 10).astype(np.float32)
    X = torch.tensor(X).to(device)
    # Y = np.random.rand(100, 1).astype(np.float32) # torch.ones((100, 1))
    # Y = torch.tensor(Y).to(device)
    Y = torch.ones((100, 1)).to(device)
    Z = torch.zeros((100, 1)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.9))
    model.to(device)
    loss = 0
    val_loss = 0
    for ind in range(10000):
        model.train()
        optimizer.zero_grad()
        X_pred = model(X)
        loss = nn.MSELoss()(X_pred, Y)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        if CASE == 0:
            ## No validation
            pass
        else:
            model.eval()
            optimizer.zero_grad()
            Z_pred = model(X)
            loss = nn.MSELoss()(Z_pred, Z)
            loss.backward()  # <-- It's wrong !
            # no optimizer.step() here !
            val_loss = loss.item()
        print("\r Iter {} Train Loss: {:.10f} | Val Loss: {:.10f}".format(ind+1, loss, val_loss), end='')
    print()

By switching between CASE = 0 (no validation) and CASE = 1 (with the wrong validation mechanism), you will see that one linear layer is sufficient to learn the task (CASE = 0, e.g. Train Loss: 0.0000001010). However, with CASE = 1, the gradient of the validation set will affect the training and the model will not converge (e.g. Train Loss: 1.0033100843 | Val Loss: 1.0033100843).
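The mechanism behind this concern can be checked in isolation. The sketch below is a minimal illustration, not srl-zoo code; the toy nn.Linear model and the tensor names are assumptions made for the example. It shows that backward() adds into each parameter's .grad and that the values stay there until they are zeroed, so a gradient computed on validation data reaches optimizer.step() only if nothing clears .grad in between:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for a model and a train/validation batch.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
x = torch.rand(8, 10)
y_train = torch.ones(8, 1)
y_valid = torch.zeros(8, 1)

# Gradient of the training loss alone.
model.zero_grad()
loss_fn(model(x), y_train).backward()
grad_train = model.weight.grad.clone()

# Gradient of the "validation" loss alone.
model.zero_grad()
loss_fn(model(x), y_valid).backward()
grad_valid = model.weight.grad.clone()

# Without zeroing in between, the second backward() adds to the first.
model.zero_grad()
loss_fn(model(x), y_valid).backward()
loss_fn(model(x), y_train).backward()
grad_accumulated = model.weight.grad.clone()

# With zeroing in between, only the training gradient remains.
model.zero_grad()
loss_fn(model(x), y_valid).backward()
model.zero_grad()
loss_fn(model(x), y_train).backward()
grad_after_zeroing = model.weight.grad.clone()

print(torch.allclose(grad_accumulated, grad_train + grad_valid))
print(torch.allclose(grad_after_zeroing, grad_train))
```

When checking a training loop for this problem, it is therefore worth looking at where zero_grad() is called relative to every backward() and step().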
SOLUTION
It's easy: just remove loss.backward() in validation mode, and call torch.set_grad_enabled(False) at the beginning of validation and torch.set_grad_enabled(True) at the end. This will also provide a speed-up of around 5-10%, depending on the model.
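A minimal sketch of that fix, assuming a hypothetical validate helper and a plain list of (X, Y) batches standing in for the validation loader (neither is taken from srl-zoo). It restores the previous grad mode and training mode in a finally block so that training continues normally afterwards:

```python
import torch
import torch.nn as nn


def validate(model, batches, loss_fn):
    """Return the mean loss over `batches` without tracking gradients."""
    was_training = model.training
    grad_was_enabled = torch.is_grad_enabled()
    model.eval()
    torch.set_grad_enabled(False)  # beginning of validation
    try:
        total = 0.0
        for X, Y in batches:
            # Forward pass only: no loss.backward() in validation.
            total += loss_fn(model(X), Y).item()
        return total / max(len(batches), 1)
    finally:
        torch.set_grad_enabled(grad_was_enabled)  # end of validation
        model.train(was_training)


# Hypothetical toy model and validation batches for the example.
model = nn.Linear(10, 1)
batches = [(torch.rand(4, 10), torch.zeros(4, 1)) for _ in range(3)]
val_loss = validate(model, batches, nn.MSELoss())
print(val_loss)
```

torch.set_grad_enabled can also be used as a context manager (with torch.set_grad_enabled(False): ...), which restores the previous mode automatically.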