Description
I got a negative loss value while training a CLIP model: am I doing something wrong, or is it a loss function bug?
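My only guess so far: with decoupled_contrastive_learning=True the loss may not be bounded below by zero, because decoupled contrastive learning (Yeh et al., 2022) drops the positive pair from the InfoNCE denominator. A minimal numeric sketch of that effect (an illustration of the general formulation only, not dalle2_pytorch's actual code):

import torch
import torch.nn.functional as F

# Toy 2x2 text-image similarity logits; positives on the diagonal.
sim = torch.tensor([[3.0, 0.1],
                    [0.2, 2.5]])

# Standard InfoNCE keeps the positive pair in the denominator, so loss >= 0.
infonce = F.cross_entropy(sim, torch.arange(2))

# The decoupled variant removes the positive pair from the denominator:
# per-sample loss = -pos + logsumexp(negatives), which can go negative
# once the positive logit exceeds the logsumexp of the negatives.
diag = torch.eye(2, dtype=torch.bool)
dcl = (-sim.diag() + sim.masked_fill(diag, float('-inf')).logsumexp(dim=-1)).mean()

print(infonce.item())  # ~0.07, non-negative
print(dcl.item())      # -2.6, negative by construction

Is that enough to explain the sign, or is something else off in my setup?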
Here is my code:
import torch
from torch.utils.data import DataLoader
from dalle2_pytorch import CLIP
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dataset import TextImgDataset
from tqdm import tqdm
from torch.optim import AdamW
clip = CLIP(
    dim_text=512,
    dim_image=32,
    dim_latent=512,
    num_text_tokens=49408,
    text_enc_depth=1,
    text_seq_len=256,
    text_heads=8,
    visual_enc_depth=1,
    visual_image_size=256,
    visual_patch_size=32,
    visual_heads=8,
    use_all_token_embeds=True,
    decoupled_contrastive_learning=True,
    extra_latent_projection=True,
    use_visual_ssl=True,
    visual_ssl_type='simclr',
    use_mlm=False,
    text_ssl_loss_weight=0.05,
    image_ssl_loss_weight=0.05
).cuda()
optim = AdamW(clip.parameters(), lr=3e-4)
dataloader = DataLoader(
    dataset=TextImgDataset('hf://datasets/pranked03/flowers-blip-captions/data/train-00000-of-00001-f41d4839cc8f6449.parquet'),
    batch_size=4,
    shuffle=False,
)
t = SimpleTokenizer()
# Early stopping parameters
patience = 10
best_loss = float('inf')
trigger_times = 0
for epoch in range(1, 500):
    losses = []
    for image, text in tqdm(dataloader, desc=f'epoch {epoch}'):
        optim.zero_grad()
        loss = clip(
            t.tokenize(text).cuda(),
            image.cuda(),
            return_loss=True
        )
        loss.backward()
        losses.append(loss.item())
        optim.step()
    epoch_loss = sum(losses) / len(losses)
    print(f"epoch {epoch}, loss: {epoch_loss}")

    # Check for early stopping
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        trigger_times = 0
        torch.save(clip.state_dict(), 'clip.pt')  # Save the model when it improves
    else:
        trigger_times += 1
        print(f"Trigger times: {trigger_times}")
        if trigger_times >= patience:
            print("Early stopping!")
            break
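For completeness, I reload the best checkpoint afterwards with the standard state-dict API:

# Restore the best weights saved by the early-stopping loop above.
clip.load_state_dict(torch.load('clip.pt'))
clip.eval()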
Custom dataset class (dataset.py):
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T
import pandas as pd
from PIL import Image
import io
# Dataset class, returns images and corresponding textual captions
class TextImgDataset(Dataset):
    def __init__(self, fp: str):
        self.df = pd.read_parquet(fp)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((32, 32)),
            T.ToTensor(),
        ])

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx) -> tuple[torch.Tensor, str]:
        row = self.df.iloc[idx]
        img_bytes = io.BytesIO(row['image']['bytes'])
        image = Image.open(img_bytes)
        image_tensor = self.transform(image)
        caption = row['text']
        return image_tensor, caption
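A quick sanity check on one sample (hypothetical snippet; it assumes the same parquet file is reachable):

ds = TextImgDataset('hf://datasets/pranked03/flowers-blip-captions/data/train-00000-of-00001-f41d4839cc8f6449.parquet')
image, caption = ds[0]
print(image.shape)  # torch.Size([3, 32, 32]) from the Resize((32, 32)) above
print(caption)      # raw caption string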