
Commit 5e21bbf

Peng Chen authored and facebook-github-bot committed
add blip2 loss under torchmultimodal/modules/losses (#485)
Summary: Pull Request resolved: #485. As title.

Differential Revision: D50148648

fbshipit-source-id: 04f3fbb096c8dde167c792f5496bd87cd295c8c0
1 parent 6ffde67 commit 5e21bbf
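
For orientation, below is a minimal usage sketch of the new loss, assembled from the fixtures and assertions in the test file in this diff; the tiny model sizes are the test's own, and nothing beyond the constructor flags and forward keywords exercised there is assumed.

import torch
from torch import nn
from torchmultimodal.models.blip2.blip2 import BLIP2, Blip2Output
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss

# Tiny vision encoder, Q-Former and BLIP2 model, mirroring the test fixtures below.
vit = VisionTransformer(
    embeddings=PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2),
    encoder=TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    ),
)
qformer = QformerForCLM(
    dim_q=4,
    dim_kv=2,
    dim_feedforward=6,
    num_heads=2,
    attn_dropout=0.0,
    dropout=0.0,
    num_hidden_layers=2,
    max_position_embeddings=512,
    vocab_size=20,
)
blip2 = BLIP2(
    dim_q=4,
    image_encoder_embedding_dim=2,
    qformer=qformer,
    vision_encoder=vit,
    embedding_dim=4,
    decoder_bos_token_id=19,
)

# A hand-built Blip2Output for a batch of 4, as in the blip2_output fixture.
model_output = Blip2Output(
    image_embeddings=torch.ones([4, 5, 2]),
    image_features=torch.ones([4, 32, 4]) * 0.5,
    image_qformer_output=torch.ones([4, 32, 4]) * 0.5,
    text_features=torch.ones([4, 4]) * 0.5,
    prediction_scores=torch.ones([4, 4, 20]) * 5,
)

# Phase-1 loss combining the ITC/ITM/ITG objectives; the enable_* flags switch
# individual terms off, and disabling all three raises a ValueError.
loss_fn = Blip2Phase1Loss(dim_q=4)
output = loss_fn(
    model_output=model_output,
    blip2=blip2,
    input_ids=torch.ones(4, 4).long(),  # tokenized captions, shape [batch, seq_len]
    attention_mask=torch.ones([4, 4]),
)
print(output.total_loss.item())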

File tree

2 files changed: 691 additions, 0 deletions

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import pytest
import torch
from tests.test_utils import (
    assert_expected,
    gpu_test,
    init_distributed_on_file,
    init_weights_with_constant,
    with_temp_files,
)
from torch import distributed as dist, multiprocessing as mp, nn, optim
from torchmultimodal.models.blip2.blip2 import BLIP2, Blip2Output
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss


@pytest.fixture
def dim_q():
    return 4


@pytest.fixture
def dim_kv():
    return 2


@pytest.fixture
def dim_feedforward():
    return 6


@pytest.fixture
def num_hidden_layers():
    return 2


@pytest.fixture
def num_heads():
    return 2


@pytest.fixture
def vocab_size():
    return 20


@pytest.fixture
def vit():
    embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2)
    encoder = TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    )
    image_encoder = VisionTransformer(
        embeddings=embedding,
        encoder=encoder,
    )
    init_weights_with_constant(image_encoder)
    image_encoder.eval()
    return image_encoder


class TestBLIP2Stage1Loss:
    @pytest.fixture
    def images(self):
        return torch.ones(4, 3, 2, 2)

    @pytest.fixture
    def input_ids(self):
        return torch.ones(4, 4).long()

    @pytest.fixture
    def all_attn_mask(self):
        return torch.ones([4, 4])

    @pytest.fixture
    def global_batch_size(self):
        return 4

    @pytest.fixture
    def qformer_model_for_clm(
        self,
        dim_q,
        dim_kv,
        dim_feedforward,
        num_hidden_layers,
        num_heads,
        vocab_size,
    ):
        qformer_for_clm = QformerForCLM(
            dim_q=dim_q,
            dim_kv=dim_kv,
            dim_feedforward=dim_feedforward,
            num_heads=num_heads,
            attn_dropout=0.0,
            dropout=0.0,
            num_hidden_layers=num_hidden_layers,
            max_position_embeddings=512,
            vocab_size=vocab_size,
        )
        return qformer_for_clm

    @pytest.fixture
    def blip2_output(self):
        return Blip2Output(
            image_embeddings=torch.ones([4, 5, 2]),
            image_features=torch.ones([4, 32, 4]) * 0.5,
            image_qformer_output=torch.ones([4, 32, 4]) * 0.5,
            text_features=torch.ones([4, 4]) * 0.5,
            prediction_scores=torch.ones([4, 4, 20]) * 5,
        )

    @pytest.fixture
    def blip2(self, dim_q, dim_kv, qformer_model_for_clm, vit):
        blip2 = BLIP2(
            dim_q=dim_q,
            image_encoder_embedding_dim=dim_kv,
            qformer=qformer_model_for_clm,
            vision_encoder=vit,
            embedding_dim=4,
            decoder_bos_token_id=19,
        )
        init_weights_with_constant(blip2)
        blip2.eval()
        return blip2

    def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 5.07517, rtol=0, atol=1e-4)

    def test_local_itc_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itm=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 1.38629, rtol=0, atol=1e-4)

    def test_local_itm_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 0.69315, rtol=0, atol=1e-4)

    def test_local_itg_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itm=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 2.9957, rtol=0, atol=1e-4)

    def test_invalid_loss_input(self, dim_q):
        # Disabling all three phase-1 objectives (ITC, ITM, ITG) is invalid.
        with pytest.raises(ValueError):
            Blip2Phase1Loss(
                dim_q=dim_q, enable_itc=False, enable_itm=False, enable_itg=False
            )

    @staticmethod
    def _model_worker(
        gpu_id: int,
        sync_file: str,
        world_size: int,
        global_batch_size: int,
        all_images: torch.Tensor,
        all_input_ids: torch.Tensor,
        all_attn_mask: torch.Tensor,
        blip2_output: Blip2Output,
        blip2: nn.Module,
        dim_q: int,
    ):
        init_distributed_on_file(
            world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
        )
        assert global_batch_size % world_size == 0
        local_batch_size = global_batch_size // world_size
        all_attn_mask = torch.ones([4, 4])

        # Split inputs across GPUs
        local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id)
        local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        local_attn_mask = torch.split(all_attn_mask, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        assert blip2_output.text_features is not None
        assert blip2_output.prediction_scores is not None
        local_blip2_output = Blip2Output(
            image_embeddings=torch.split(
                blip2_output.image_embeddings, local_batch_size
            )[gpu_id].cuda(gpu_id),
            image_features=torch.split(blip2_output.image_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            image_qformer_output=torch.split(
                blip2_output.image_qformer_output, local_batch_size
            )[gpu_id].cuda(gpu_id),
            text_features=torch.split(blip2_output.text_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            prediction_scores=torch.split(
                blip2_output.prediction_scores, local_batch_size
            )[gpu_id].cuda(gpu_id),
        )

        blip2 = blip2.cuda(gpu_id)
        loss_fn = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(loss_fn)
        loss_fn = loss_fn.cuda(gpu_id)

        all_params = chain(blip2.parameters(), loss_fn.parameters())

        optimizer = optim.SGD(all_params, lr=1e-4)

        # Forward pass
        loss = loss_fn(
            model_output=local_blip2_output,
            blip2=blip2,
            images=local_images,
            input_ids=local_input_ids,
            attention_mask=local_attn_mask,
        ).total_loss

        # Compute gradients
        optimizer.zero_grad()
        loss.backward()

        # Gather gradients from all devices
        def gather_grads(x: torch.Tensor) -> torch.Tensor:
            grads = [torch.zeros_like(x).cuda(gpu_id) for _ in range(world_size)]
            dist.all_gather(grads, x)
            grad = torch.stack(grads).mean()
            return grad

        # Gather losses from all devices
        gathered_loss = gather_grads(torch.Tensor([loss]).cuda(gpu_id))
        assert_expected(gathered_loss.item(), 5.07517, rtol=0, atol=1e-4)

    @gpu_test(gpu_count=1)
    def test_single_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 1
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )

    @gpu_test(gpu_count=2)
    def test_multi_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 2
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )
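
The distributed check in _model_worker gathers each rank's scalar loss and averages it. As a standalone restatement of that pattern (the helper name is illustrative, and an already-initialized process group is assumed):

import torch
import torch.distributed as dist


def gather_mean_scalar(value: torch.Tensor, world_size: int, gpu_id: int) -> torch.Tensor:
    # Collect one copy of `value` from every rank into a pre-allocated list,
    # then average across ranks, as gather_grads/gathered_loss do above.
    buckets = [torch.zeros_like(value).cuda(gpu_id) for _ in range(world_size)]
    dist.all_gather(buckets, value)
    return torch.stack(buckets).mean()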
