-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_embeddings.py
More file actions
76 lines (56 loc) · 2.25 KB
/
generate_embeddings.py
File metadata and controls
76 lines (56 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch.utils.data import DataLoader
from src.datasets.polyvore_embed import DatasetArguments, PolyvoreDataset
from src.models.embedder import CLIPEmbeddingModel
from src.models.recommender import RecommendationModel
from src.models.load import load_model
from src.loss.focal_loss import focal_loss
from src.utils.utils import save_model
import os
import wandb
import numpy as np
from bitsandbytes.optim import AdamW8bit
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim import AdamW
from tqdm import tqdm
from datetime import datetime
from dataclasses import dataclass
from sklearn.metrics import roc_auc_score
import pickle
from model_args import Args
# Script-level configuration: start from the shared model arguments and
# override only the inference-time knobs used for embedding generation.
args = Args()
# Embedding Setting
args.num_workers = 4   # DataLoader worker processes for input preprocessing
args.batch_size = 512  # items encoded per forward pass (inference-only, so large)
args.with_cuda = True  # request GPU; falls back to CPU if CUDA is unavailable
def generate(model, dataloader, device):
    """Encode every batch from *dataloader* and collect item ids + embeddings.

    Args:
        model: embedding model exposing ``encode(inputs)`` that returns a
            dict whose second value is the batch embedding tensor.
        dataloader: yields ``(item_ids, inputs)`` pairs, where ``inputs`` is
            a dict of tensors ready to be moved to *device*.
        device: torch device the input tensors are transferred to.

    Returns:
        tuple: ``(allids, allembeddings)`` — a list of item ids and a
        float32 numpy array of shape (num_items, embed_dim).
    """
    epoch_iterator = tqdm(dataloader)
    allids, allembeddings = [], []
    with torch.no_grad():
        # NOTE: dropped the original's unused `type_str` variable and the
        # unused enumerate counter (which shadowed the builtin `iter`).
        for batch in epoch_iterator:
            item_ids, inputs = batch
            allids.extend(item_ids)
            # Mixed-precision forward pass; a no-op on CPU.
            with torch.cuda.amp.autocast():
                inputs = {key: value.to(device) for key, value in inputs.items()}
                # encode() returns a dict; unpacking `.values()` assumes it
                # has exactly two entries with the embeddings second —
                # TODO confirm against CLIPEmbeddingModel.encode.
                _, batch_embeddings = model.encode(inputs).values()
            # Move off the GPU and upcast fp16 -> fp32 before accumulating.
            batch_embeddings = batch_embeddings.cpu().float()
            allembeddings.append(batch_embeddings)
    allembeddings = torch.cat(allembeddings, dim=0).numpy()
    return allids, allembeddings
if __name__ == '__main__':
    # Select GPU 0 when requested and available; otherwise run on CPU.
    use_cuda = args.with_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Load the model + input processor, move to the target device, and
    # switch to fp16 inference mode.
    model, input_processor = load_model(args)
    model.to(device)
    model = model.half()
    model.eval()

    # Deterministic, unshuffled pass over the whole dataset.
    dataset = PolyvoreDataset(args.data_dir, args, input_processor)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    with torch.no_grad():
        allids, allembeddings = generate(model, dataloader=dataloader, device=device)

    # Persist (ids, embeddings) next to the script as a pickle.
    save_file = os.path.join('./', 'embeddings.pkl')
    with open(save_file, mode="wb") as f:
        pickle.dump((allids, allembeddings), f)
    print(f"Total passages processed {len(allids)}. Written to {save_file}.")