Skip to content

Negative loss on CLIP training #321

Open
@u1ug

Description

@u1ug

I got a negative loss value while training a CLIP model. Am I doing something wrong, or is this a bug in the loss function?

Here is my code

import torch
from torch.utils.data import DataLoader
from dalle2_pytorch import CLIP
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dataset import TextImgDataset
from tqdm import tqdm
from torch.optim.adamw import AdamW

# Hyperparameters for the CLIP model, gathered in one mapping so they can be
# inspected or logged before the model is built.
#
# NOTE(review): with decoupled_contrastive_learning=True the positive pair is
# presumably left out of the softmax denominator, in which case the loss is
# not bounded below by zero and negative values would be expected rather than
# a bug -- confirm against the x-clip implementation.
# NOTE(review): visual_image_size is 256 while the dataset class posted in
# this issue resizes every image to 32x32 -- confirm the mismatch is intended.
clip_kwargs = {
    # Text encoder
    'dim_text': 512,
    'num_text_tokens': 49408,
    'text_enc_depth': 1,
    'text_seq_len': 256,
    'text_heads': 8,
    # Image encoder
    'dim_image': 32,
    'visual_enc_depth': 1,
    'visual_image_size': 256,
    'visual_patch_size': 32,
    'visual_heads': 8,
    # Shared latent space and contrastive objective
    'dim_latent': 512,
    'use_all_token_embeds': True,
    'decoupled_contrastive_learning': True,
    'extra_latent_projection': True,
    # Auxiliary self-supervised losses
    'use_visual_ssl': True,
    'visual_ssl_type': 'simclr',
    'use_mlm': False,
    'text_ssl_loss_weight': 0.05,
    'image_ssl_loss_weight': 0.05,
}

# Build the model and move it to the GPU.
clip = CLIP(**clip_kwargs).cuda()

# AdamW over every parameter of the CLIP model.
optim = AdamW(clip.parameters(), lr=3e-4)

# Captioned flower images read straight from the Hugging Face hub.
dataset = TextImgDataset(
    'hf://datasets/pranked03/flowers-blip-captions/data/train-00000-of-00001-f41d4839cc8f6449.parquet'
)

# Reshuffle at every epoch. With shuffle=False each epoch replays exactly the
# same batches of 4, so every sample is only ever paired with the same three
# neighbours; for a contrastive objective that presumably means the in-batch
# negatives never change.
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# Tokenizer that turns the caption strings into token-id tensors.
t = SimpleTokenizer()

# Early-stopping state: training ends once `patience` consecutive epochs have
# failed to beat the best mean epoch loss seen so far.
patience = 10
best_loss = float('inf')  # any finite loss, negative ones included, beats this
trigger_times = 0

# At most 499 epochs (1..499).
for epoch in range(1, 500):
    batch_losses = []
    progress = tqdm(dataloader, desc=f'epoch {epoch}')

    for image, text in progress:
        # Move both modalities to the GPU before the forward pass.
        tokens = t.tokenize(text).cuda()
        pixels = image.cuda()

        optim.zero_grad()
        loss = clip(tokens, pixels, return_loss=True)
        loss.backward()
        batch_losses.append(loss.item())
        optim.step()

    # Mean of the per-batch losses of this epoch.
    epoch_loss = sum(batch_losses) / len(batch_losses)
    print(f"epoch {epoch}, loss: {epoch_loss}")

    # New best: remember it, reset the patience counter, checkpoint the model.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        trigger_times = 0
        torch.save(clip.state_dict(), 'clip.pt')
        continue

    # No improvement this epoch.
    trigger_times += 1
    print(f"Trigger times: {trigger_times}")

    if trigger_times >= patience:
        print("Early stopping!")
        break

Custom dataset class

import torch
from torch.utils.data import Dataset
from torchvision import transforms as T
import pandas as pd
from PIL import Image
import io


class TextImgDataset(Dataset):
    """Dataset of images and their textual captions read from a parquet file.

    Every row is read as ``row['image']['bytes']`` (the encoded image) and
    ``row['text']`` (the caption).

    Args:
        fp: Location of the parquet file, passed to ``pd.read_parquet``.
        image_size: Side length in pixels that every image is resized to.
            Defaults to 32, the size this class has always produced.
    """

    def __init__(self, fp: str, image_size: int = 32):
        self.df = pd.read_parquet(fp)
        self.transform = T.Compose([
            # A named function rather than an inline lambda: lambdas cannot be
            # pickled, which breaks DataLoader workers started with "spawn".
            T.Lambda(self._to_rgb),
            T.Resize((image_size, image_size)),
            T.ToTensor(),
        ])

    @staticmethod
    def _to_rgb(img):
        """Return ``img`` in RGB mode, converting only when it is not already."""
        return img if img.mode == 'RGB' else img.convert('RGB')

    def __len__(self) -> int:
        """Return the number of rows in the underlying dataframe."""
        return self.df.shape[0]

    def __getitem__(self, idx) -> "tuple[torch.Tensor, str]":
        """Return the transformed image and the caption of row ``idx``."""
        row = self.df.iloc[idx]
        img_bytes = io.BytesIO(row['image']['bytes'])
        # The transform runs inside the `with` so the pixel data is read while
        # the image is open; the image is closed as soon as the block exits.
        with Image.open(img_bytes) as image:
            image_tensor = self.transform(image)
        caption = row['text']

        return image_tensor, caption

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions