feat(logging): add positive/negative cosine-similarity and embedding … #52

Open · wants to merge 3 commits into master
101 changes: 82 additions & 19 deletions main.py
@@ -3,6 +3,7 @@
import torch
import torchvision
import argparse
import torch.nn.functional as F # For cosine similarity

# distributed training
import torch.distributed as dist
@@ -16,26 +17,26 @@
# SimCLR
from simclr import SimCLR
from simclr.modules import NT_Xent, get_resnet
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.sync_batchnorm import convert_model

from model import load_optimizer, save_model
from utils import yaml_config_hook


def train(args, train_loader, model, criterion, optimizer, writer):
loss_epoch = 0
for step, ((x_i, x_j), _) in enumerate(train_loader):
optimizer.zero_grad()
x_i = x_i.cuda(non_blocking=True)
x_j = x_j.cuda(non_blocking=True)

        # Positive pair, with encoding
h_i, h_j, z_i, z_j = model(x_i, x_j)

loss = criterion(z_i, z_j)
loss.backward()


optimizer.step()

if dist.is_available() and dist.is_initialized():
@@ -45,10 +46,41 @@ def train(args, train_loader, model, criterion, optimizer, writer):
if args.nr == 0 and step % 50 == 0:
print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

# Log the loss for every step.
if args.nr == 0:
writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
writer.add_scalar("Loss/train_epoch_step", loss.item(), args.global_step)
args.global_step += 1

# Log augmented images every 100 steps.
if args.nr == 0 and step % 100 == 0:
grid_xi = torchvision.utils.make_grid(x_i[:16], nrow=4, normalize=True)
grid_xj = torchvision.utils.make_grid(x_j[:16], nrow=4, normalize=True)
writer.add_image("Augmented Images/x_i", grid_xi, args.global_step)
writer.add_image("Augmented Images/x_j", grid_xj, args.global_step)

# Compute and log cosine similarity between positive and negative pair representations.
with torch.no_grad():
# Positive cosine similarities.
cos_sim = F.cosine_similarity(z_i, z_j, dim=-1)
avg_cos_sim = cos_sim.mean().item()
writer.add_scalar("Cosine Similarity/avg_positive", avg_cos_sim, args.global_step)
writer.add_histogram("Cosine Similarity/hist_positive", cos_sim, args.global_step)

# Compute negative similarities.
norm_z_i = F.normalize(z_i, dim=1)
norm_z_j = F.normalize(z_j, dim=1)
cosine_matrix = torch.mm(norm_z_i, norm_z_j.t()) # (batch_size, batch_size)
# Mask the diagonal (positive pairs)
mask = torch.eye(cosine_matrix.size(0), dtype=torch.bool, device=cosine_matrix.device)
negative_sim = cosine_matrix[~mask].view(cosine_matrix.size(0), -1)
avg_cos_sim_neg = negative_sim.mean().item()
writer.add_scalar("Cosine Similarity/avg_negative", avg_cos_sim_neg, args.global_step)
writer.add_histogram("Cosine Similarity/hist_negative", negative_sim, args.global_step)

# Log embedding standard deviation as a collapse indicator.
std_z_i = norm_z_i.std(dim=0).mean().item()
writer.add_scalar("Embeddings/std_dev", std_z_i, args.global_step)

loss_epoch += loss.item()
return loss_epoch
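As a reference for the logging added above, here is a minimal standalone sketch of the positive/negative cosine-similarity computation on random tensors; the batch size and projection dimension below are illustrative and not taken from `config.yaml`:

```python
import torch
import torch.nn.functional as F

batch_size, proj_dim = 8, 64  # illustrative values only
z_i = torch.randn(batch_size, proj_dim)
z_j = torch.randn(batch_size, proj_dim)

# Positive similarities: one value per pair (z_i[k], z_j[k]).
pos = F.cosine_similarity(z_i, z_j, dim=-1)  # shape: (batch_size,)

# Negative similarities: pairwise matrix of normalized projections with the
# diagonal (the positive pairs) masked out.
cos = torch.mm(F.normalize(z_i, dim=1), F.normalize(z_j, dim=1).t())  # (batch_size, batch_size)
mask = torch.eye(batch_size, dtype=torch.bool)
neg = cos[~mask].view(batch_size, -1)  # shape: (batch_size, batch_size - 1)

print(pos.mean().item(), neg.mean().item())  # values logged as avg_positive / avg_negative
```

Note that this logs only cross-view negatives (each z_i against every non-matching z_j); NT-Xent also contrasts within-view pairs, so the logged averages are a monitoring proxy rather than the exact negative set used by the loss.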

@@ -63,6 +95,7 @@ def main(gpu, args):
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Data loading: support for STL10, CIFAR10, and a custom dataset (e.g., dogs vs. cats)
if args.dataset == "STL10":
train_dataset = torchvision.datasets.STL10(
args.dataset_dir,
@@ -76,9 +109,15 @@
download=True,
transform=TransformsSimCLR(size=args.image_size),
)
elif args.dataset == "custom":
train_dataset = torchvision.datasets.ImageFolder(
os.path.join(args.dataset_dir, "dataset/training_set/"),
transform=TransformsSimCLR(size=args.image_size)
)
else:
raise NotImplementedError

# Create a DistributedSampler if using more than one node.
if args.nodes > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True
@@ -95,11 +134,11 @@ def main(gpu, args):
sampler=train_sampler,
)

    # Initialize ResNet
encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # Initialize model
model = SimCLR(encoder, args.projection_dim, n_features)
if args.reload:
model_fp = os.path.join(
@@ -108,7 +147,7 @@
model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
model = model.to(args.device)

    # Set up optimizer and NT-Xent loss criterion.
optimizer, scheduler = load_optimizer(args, model)
criterion = NT_Xent(args.batch_size, args.temperature, args.world_size)

@@ -143,27 +182,53 @@ def main(gpu, args):
save_model(args, model, optimizer)

if args.nr == 0:
writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
avg_loss = loss_epoch / len(train_loader)
writer.add_scalar("Loss/train", avg_loss, epoch)
writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(f"Epoch [{epoch}/{args.epochs}]\t Loss: {avg_loss}\t lr: {round(lr, 5)}")
args.current_epoch += 1

# Log embeddings for a small subset of images.
model.eval()
try:
sample_batch = next(iter(train_loader))
# Use one view (x_i) for embedding logging.
sample_images, sample_labels = sample_batch[0][0], sample_batch[1]
sample_images = sample_images.cuda(non_blocking=True)

with torch.no_grad():
embeddings = model.encoder(sample_images)
# Convert numerical labels to metadata (e.g., "dog" or "cat")
#print("Shape of sample images:", sample_images.shape)
#assert sample_images.ndim == 4, "sample_images must be a 4D tensor [N, C, H, W]"
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
metadata = [cifar10_classes[label] for label in sample_labels]
#metadata = ["dog" if label == 1 else "cat" for label in sample_labels]
writer.add_embedding(embeddings, metadata=metadata, label_img=sample_images, global_step=epoch, tag="embeddings")
except Exception as e:
print(f"Error during embedding logging: {e}")
model.train()

# Log gradient histograms for all parameters.
for name, param in model.named_parameters():
if param.grad is not None:
writer.add_histogram(f'{name}.grad', param.grad, epoch)

# End training: save final model checkpoint.
save_model(args, model, optimizer)

if writer:
writer.close()
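For reference, a self-contained sketch of the TensorBoard embedding-projector call used in the per-epoch logging above, with dummy tensors; the tensor sizes, class names, and default `./runs` log directory are illustrative assumptions:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # writes to ./runs/... by default

# Dummy batch: 16 RGB thumbnails and 512-d encoder features (illustrative sizes).
images = torch.rand(16, 3, 32, 32)
embeddings = torch.randn(16, 512)
labels = torch.randint(0, 10, (16,))
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
metadata = [classes[int(i)] for i in labels]

# label_img makes the projector display each point as its image thumbnail.
writer.add_embedding(embeddings, metadata=metadata, label_img=images,
                     global_step=0, tag="embeddings")
writer.close()
```

The run can then be inspected with `tensorboard --logdir runs`. Note that the hard-coded CIFAR-10 class list in the diff only matches `--dataset CIFAR10`; for the custom ImageFolder dataset something like `train_dataset.classes` (or the commented-out dog/cat mapping) would likely be needed.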


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("./config/config.yaml")
for k, v in config.items():
parser.add_argument(f"--{k}", default=v, type=type(v))

args = parser.parse_args()

    # Set master address for distributed training (if applicable).
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8000"

@@ -175,9 +240,7 @@ def main(gpu, args):
args.world_size = args.gpus * args.nodes

if args.nodes > 1:
        print(f"Training with {args.nodes} nodes, waiting until all nodes join before starting training")
mp.spawn(main, args=(args,), nprocs=args.gpus, join=True)
else:
main(0, args)
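Finally, a small sanity-check sketch for the collapse indicator logged in train(): for L2-normalized embeddings, the mean per-dimension standard deviation is roughly 1/sqrt(d) when features are well spread and close to 0 when all samples collapse to the same point (the dimensions below are illustrative):

```python
import torch
import torch.nn.functional as F

d = 64  # illustrative projection dimension
spread = F.normalize(torch.randn(256, d), dim=1)                   # well-spread embeddings
collapsed = F.normalize(torch.randn(1, d), dim=1).repeat(256, 1)   # every sample identical

print(spread.std(dim=0).mean().item())     # ~ 1 / sqrt(d) = 0.125
print(collapsed.std(dim=0).mean().item())  # ~ 0: rows are identical -> collapse
```

In the diff this statistic is computed on norm_z_i, so a steadily shrinking Embeddings/std_dev curve would point to dimensional collapse of the projection-head output.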