-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcoca.py
More file actions
69 lines (52 loc) · 2.04 KB
/
coca.py
File metadata and controls
69 lines (52 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
# import vision transformer
from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
# Image encoder for CoCa: a SimpleViT with patch dropout, wrapped in an
# Extractor. return_embeddings_only=True makes the wrapped model hand back the
# vision transformer's embeddings instead of its classification output, which is
# what CoCa's img_encoder expects. detach=False, as its name suggests, leaves
# those embeddings un-detached so the CoCa losses can (presumably) backpropagate
# into the ViT.
vit = Extractor(
    SimpleViT(
        image_size=256,
        patch_size=32,
        num_classes=1000,
        dim=1024,           # embedding width; CoCa's image_dim below is set to match
        depth=6,
        heads=16,
        mlp_dim=2048,
        patch_dropout=0.5,  # https://arxiv.org/abs/2212.00794
    ),
    return_embeddings_only=True,
    detach=False,
)
# Import CoCa and instantiate it around the ViT image encoder.
from coca_pytorch.coca_pytorch import CoCa

# Prefer the GPU when one is available, otherwise fall back to the CPU, so the
# example also runs on machines without CUDA (a bare .cuda() raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

coca = CoCa(
    dim=512,                     # model dimension
    img_encoder=vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim=1024,              # image embedding dimension, if not the same as model dimensions (matches the ViT's dim)
    num_tokens=20000,            # number of text tokens
    unimodal_depth=6,            # depth of the unimodal transformer
    multimodal_depth=6,          # depth of the multimodal transformer
    dim_head=64,                 # dimension per attention head
    heads=8,                     # number of attention heads
    caption_loss_weight=1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight=1.,  # weight on the contrastive loss between image and text CLS embeddings
).to(device)  # moves CoCa's parameters (and those of its submodules) to `device`
# Mock text and images, placed on whatever device the model lives on so this
# step does not hard-code CUDA. The values are drawn on the CPU first (the same
# RNG stream as a plain torch.randint/torch.randn call) and then moved.
device = next(coca.parameters()).device
# Token ids: the upper bound (exclusive) should equal CoCa's num_tokens (20000).
text = torch.randint(0, 20000, (4, 512)).to(device)
# Images: the spatial size should match the ViT's image_size (256);
# presumably laid out as (batch, channels, height, width).
images = torch.randn(4, 3, 256, 256).to(device)

# Train by giving CoCa your text and images with `return_loss = True`.
# NOTE: backward() only accumulates gradients into .grad; nothing here updates
# the weights - pair it with an optimizer's step()/zero_grad() in real training.
loss = coca(
    text=text,
    images=images,
    return_loss=True,  # set this to True to get the full caption + contrastive loss
)
loss.backward()
# Do the above for as much text and images as you like, then run inference.
# The model is switched to eval mode so training-only behaviour is turned off.
# In particular, the ViT was built with patch_dropout=0.5; vit_pytorch skips
# patch dropout outside training mode (confirm against the installed version).
# Otherwise these "inference" outputs would be computed from a random half of
# the patches.
# no_grad() avoids building an autograd graph that is never backpropagated.
# The previous train/eval mode is restored afterwards (even on error), so any
# further training steps behave as before.
was_training = coca.training
coca.eval()
try:
    with torch.no_grad():
        # caption logits
        logits = coca(
            text=text,
            images=images,
        )  # (4, 512, 20000)

        # the CLIP-like text and image embeddings
        text_embeds, image_embeds = coca(
            text=text,
            images=images,
            return_embeddings=True,
        )  # (4, 512), (4, 512)
finally:
    coca.train(was_training)