Skip to content

Commit ecc7597

Browse files
author
宣源
committed
qwen support multi gpu
1 parent 5ffaab0 commit ecc7597

File tree

1 file changed

+176
-17
lines changed

1 file changed

+176
-17
lines changed

models/qwenimage.py

Lines changed: 176 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
2626
from nunchaku.models.utils import CPUOffloadManager
2727
from nunchaku.ops.fused import fused_gelu_mlp
28+
from dist_utils import all_all_async, args, tensor_chunk, all_gather, all_all, has_nvlink, all_gather_async
29+
import logging
2830

2931
from ..mixins.model import NunchakuModelMixin
3032

@@ -238,6 +240,12 @@ def __init__(
238240
self.to_add_out = SVDQW4A4Linear(
239241
self.inner_dim, self.out_context_dim, bias=out_bias, torch_dtype=dtype, device=device, **kwargs
240242
)
243+
if args.world_size>1 and (not has_nvlink):
244+
self.overlap_num = self.heads // (2 * args.world_size)
245+
if self.overlap_num > 1:
246+
if args.rank == 0:
247+
logging.info(f"no nvlink and self.overlap_num={self.overlap_num}, using compute and communication overlap")
248+
self.forward = self.forward_overlap
241249

242250
def forward(
243251
self,
@@ -246,6 +254,7 @@ def forward(
246254
encoder_hidden_states_mask: torch.FloatTensor = None,
247255
attention_mask: Optional[torch.FloatTensor] = None,
248256
image_rotary_emb: Optional[torch.Tensor] = None,
257+
transformer_options={},
249258
) -> Tuple[torch.Tensor, torch.Tensor]:
250259
"""
251260
Forward pass for double-stream attention.
@@ -274,46 +283,183 @@ def forward(
274283

275284
img_qkv = self.to_qkv(hidden_states)
276285
img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
286+
img_query = img_query.unflatten(-1, (self.heads, -1))
287+
img_key = img_key.unflatten(-1, (self.heads, -1))
288+
img_value = img_value.unflatten(-1, (self.heads, -1))
289+
b, _, _, d = img_value.shape
290+
if args.world_size>1:
291+
sp_lens = transformer_options.get('sp_len')
292+
v_data_list = tensor_chunk(img_value, -2)
293+
val_datashapes = [[b, sp_lens[rank_i], v_data_list[rank_i].size(-2), d] for rank_i in range(args.world_size)]
294+
output_val_list, val_async_worker, val_datashapes = all_all_async(img_value, -2, val_datashapes, v_data_list)
295+
img_query = self.norm_q(img_query)
296+
img_key = self.norm_k(img_key)
297+
output_q_list, q_async_worker, _ = all_all_async(img_query, -2, val_datashapes)
298+
output_k_list, k_async_worker, _ = all_all_async(img_key, -2, val_datashapes)
299+
else:
300+
img_query = self.norm_q(img_query)
301+
img_key = self.norm_k(img_key)
277302

278303
# Compute QKV for text stream (context projections)
279304
txt_qkv = self.add_qkv_proj(encoder_hidden_states)
280305
txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
281306

282-
img_query = img_query.unflatten(-1, (self.heads, -1))
283-
img_key = img_key.unflatten(-1, (self.heads, -1))
284-
img_value = img_value.unflatten(-1, (self.heads, -1))
285-
286307
txt_query = txt_query.unflatten(-1, (self.heads, -1))
287308
txt_key = txt_key.unflatten(-1, (self.heads, -1))
288309
txt_value = txt_value.unflatten(-1, (self.heads, -1))
289310

290-
img_query = self.norm_q(img_query)
291-
img_key = self.norm_k(img_key)
292311
txt_query = self.norm_added_q(txt_query)
293312
txt_key = self.norm_added_k(txt_key)
294313

295-
# Concatenate image and text streams for joint attention
296-
joint_query = torch.cat([txt_query, img_query], dim=1)
297-
joint_key = torch.cat([txt_key, img_key], dim=1)
298-
joint_value = torch.cat([txt_value, img_value], dim=1)
314+
if args.world_size>1:
315+
txt_data_list = tensor_chunk(txt_value, 2)
316+
txt_value = txt_data_list[args.rank]
317+
txt_query = tensor_chunk(txt_query, 2)[args.rank]
318+
txt_key = tensor_chunk(txt_key, 2)[args.rank]
319+
320+
val_async_worker.wait()
321+
img_value = torch.cat(output_val_list, dim=-3).contiguous()
322+
joint_value = torch.cat([txt_value, img_value], dim=1)
323+
joint_value = joint_value.flatten(start_dim=2)
324+
325+
q_async_worker.wait()
326+
img_query = torch.cat(output_q_list, dim=-3).contiguous()
327+
joint_query = torch.cat([txt_query, img_query], dim=1)
328+
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
329+
heads = joint_query.size(-2)
330+
joint_query = joint_query.flatten(start_dim=2)
331+
332+
k_async_worker.wait()
333+
img_key = torch.cat(output_k_list, dim=-3).contiguous()
334+
joint_key = torch.cat([txt_key, img_key], dim=1)
335+
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
336+
joint_key = joint_key.flatten(start_dim=2)
337+
338+
else:
339+
joint_value = torch.cat([txt_value, img_value], dim=1)
340+
joint_value = joint_value.flatten(start_dim=2)
299341

300-
# Apply rotary embeddings
301-
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
302-
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
342+
joint_query = torch.cat([txt_query, img_query], dim=1)
343+
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
344+
heads = joint_query.size(-2)
345+
joint_query = joint_query.flatten(start_dim=2)
303346

304-
joint_query = joint_query.flatten(start_dim=2)
305-
joint_key = joint_key.flatten(start_dim=2)
306-
joint_value = joint_value.flatten(start_dim=2)
347+
joint_key = torch.cat([txt_key, img_key], dim=1)
348+
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
349+
joint_key = joint_key.flatten(start_dim=2)
307350

308351
# Compute joint attention
309352
joint_hidden_states = optimized_attention_masked(
310-
joint_query, joint_key, joint_value, self.heads, attention_mask
353+
joint_query, joint_key, joint_value, heads, attention_mask
311354
)
312355

313356
# Split results back to separate streams
314357
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
315358
img_attn_output = joint_hidden_states[:, seq_txt:, :]
316359

360+
if args.world_size>1:
361+
data_shapes = [[b, sp_lens[args.rank], val_datashapes[rank_i][2] * val_datashapes[rank_i][3]] for rank_i in range(args.world_size)]
362+
img_attn_output = all_all(img_attn_output, -2, -1, data_shapes, sp_lens)
363+
txt_attn_output = all_gather([_.flatten(start_dim=2) for _ in txt_data_list], txt_attn_output, 2)
364+
365+
img_attn_output = self.to_out[0](img_attn_output)
366+
img_attn_output = self.to_out[1](img_attn_output)
367+
txt_attn_output = self.to_add_out(txt_attn_output)
368+
369+
return img_attn_output, txt_attn_output
370+
371+
def forward_overlap(
    self,
    hidden_states: torch.FloatTensor,  # Image stream
    encoder_hidden_states: torch.FloatTensor = None,  # Text stream
    encoder_hidden_states_mask: torch.FloatTensor = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
    transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Double-stream attention that pipelines communication with compute.

    The attention heads are split into ``self.overlap_num`` groups. The
    image Q/K/V of every group are handed to ``all_all_async`` up front,
    before the text projections are computed; each group's worker is only
    waited on immediately before that group's attention is evaluated. The
    per-group outputs are sent back with ``all_all_async`` /
    ``all_gather_async`` and finally concatenated along the last dim.

    Parameters
    ----------
    hidden_states : torch.FloatTensor
        Image-stream input, fed to ``self.to_qkv``.
    encoder_hidden_states : torch.FloatTensor
        Text-stream input, fed to ``self.add_qkv_proj``. Its ``shape[1]``
        is used to split the joint attention output back into text/image.
    encoder_hidden_states_mask : torch.FloatTensor, optional
        Accepted for signature parity; not read in the visible body.
    attention_mask : torch.FloatTensor, optional
        Forwarded unchanged to ``optimized_attention_masked``.
    image_rotary_emb : torch.Tensor, optional
        Forwarded to ``apply_rotary_emb`` for the joint query and key.
    transformer_options : dict, optional
        Must contain ``'sp_len'``, indexable by rank. Presumably the
        per-rank image sequence lengths -- TODO confirm against the caller
        that populates it. Also forwarded to ``optimized_attention_masked``.

    Raises
    ------
    ValueError
        If ``transformer_options`` has no ``'sp_len'`` entry.

    Notes
    -----
    The ``return`` statement of this method lies beyond the visible region
    and is intentionally not part of this block; per the annotation it is
    expected to yield ``(img_attn_output, txt_attn_output)``.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if transformer_options is None:
        transformer_options = {}

    sp_lens = transformer_options.get('sp_len')
    if sp_lens is None:
        # Fail early with a readable message instead of a TypeError on the
        # first ``sp_lens[rank_i]`` lookup below.
        raise ValueError(
            "forward_overlap requires transformer_options['sp_len'] to be set"
        )

    seq_txt = encoder_hidden_states.shape[1]
    world_size = args.world_size
    rank = args.rank

    # ---- image stream: project, split heads, normalise Q/K ----
    img_qkv = self.to_qkv(hidden_states)
    img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
    img_query = img_query.unflatten(-1, (self.heads, -1))
    img_key = img_key.unflatten(-1, (self.heads, -1))
    img_value = img_value.unflatten(-1, (self.heads, -1))
    img_query = self.norm_q(img_query)
    img_key = self.norm_k(img_key)
    b, _, _, d = img_value.shape

    # Stack Q/K/V on dim 0 so a single exchange per head group moves all
    # three, then cut dim 2 (heads) into the overlap groups.
    img_qkv = torch.cat([img_query, img_key, img_value], dim=0)
    img_qkv_lists = img_qkv.chunk(self.overlap_num, 2)

    # Launch every group's exchange now; results are consumed later.
    output_qkv_workers = []
    for qkv_data in img_qkv_lists:
        qkv_data_list = tensor_chunk(qkv_data, -2)
        qkv_datashapes = [
            [3 * b, sp_lens[rank_i], qkv_data_list[rank_i].size(-2), d]
            for rank_i in range(world_size)
        ]
        # NOTE(review): the full ``img_qkv`` is passed here although the
        # chunks being exchanged come from ``qkv_data``. Kept as in the
        # original because ``all_all_async`` is opaque from this file --
        # confirm whether ``qkv_data`` was intended.
        output_qkv_list, qkv_async_worker, _ = all_all_async(
            img_qkv, -2, qkv_datashapes, qkv_data_list
        )
        output_qkv_workers.append([output_qkv_list, qkv_async_worker])

    # ---- text stream: computed while the image exchanges are in flight ----
    txt_qkv = self.add_qkv_proj(encoder_hidden_states)
    txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)

    txt_query = txt_query.unflatten(-1, (self.heads, -1))
    txt_key = txt_key.unflatten(-1, (self.heads, -1))
    txt_value = txt_value.unflatten(-1, (self.heads, -1))

    txt_query = self.norm_added_q(txt_query)
    txt_key = self.norm_added_k(txt_key)

    txt_key_lists = txt_key.chunk(self.overlap_num, 2)
    txt_query_lists = txt_query.chunk(self.overlap_num, 2)
    txt_value_lists = txt_value.chunk(self.overlap_num, 2)
    img_attn_output_works_lists = []
    txt_attn_output_works_lists = []

    # ---- per head group: wait for its Q/K/V, attend, send results back ----
    for data_idx, (qkv_list, qkv_worker) in enumerate(output_qkv_workers):
        # This rank's share of the text heads within the current group.
        txt_key_list = tensor_chunk(txt_key_lists[data_idx], 2)
        _txt_key = txt_key_list[rank]
        _txt_query = tensor_chunk(txt_query_lists[data_idx], 2)[rank]
        _txt_value = tensor_chunk(txt_value_lists[data_idx], 2)[rank]

        qkv_worker.wait()
        # Received pieces are joined on dim 1, then the dim-0 stacking of
        # Q/K/V done before the exchange is undone.
        q, k, v = torch.cat(qkv_list, dim=1).chunk(3, 0)

        joint_value = torch.cat([_txt_value, v], dim=1)
        joint_value = joint_value.flatten(start_dim=2)

        joint_query = torch.cat([_txt_query, q], dim=1)
        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
        # Head count actually held for this group, read before flattening.
        heads = joint_query.size(-2)
        joint_query = joint_query.flatten(start_dim=2)

        joint_key = torch.cat([_txt_key, k], dim=1)
        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
        joint_key = joint_key.flatten(start_dim=2)

        joint_hidden_states = optimized_attention_masked(
            joint_query,
            joint_key,
            joint_value,
            heads,
            attention_mask,
            transformer_options=transformer_options,
        )
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]
        img_attn_output = joint_hidden_states[:, seq_txt:, :]
        data_shapes = [
            [b, sp_lens[rank], txt_key_list[rank_i].size(2) * d]
            for rank_i in range(world_size)
        ]

        # Send the outputs back asynchronously so the next group's
        # attention can start without waiting for them.
        img_attn_output_list, img_attn_output_worker, _ = all_all_async(
            img_attn_output, 1, data_shapes, tensor_chunk(img_attn_output, 1)
        )
        img_attn_output_works_lists.append(
            [img_attn_output_list, img_attn_output_worker]
        )
        txt_attn_output_list, txt_attn_output_worker = all_gather_async(
            [txt_chunk.flatten(start_dim=2) for txt_chunk in txt_key_list],
            txt_attn_output,
            2,
        )
        txt_attn_output_works_lists.append(
            [txt_attn_output_list, txt_attn_output_worker]
        )

    # ---- drain the outstanding work, in launch order ----
    img_outs = []
    txt_outs = []
    for (img_out, img_worker), (txt_out, txt_worker) in zip(
        img_attn_output_works_lists, txt_attn_output_works_lists
    ):
        img_worker.wait()
        img_outs.append(torch.cat(img_out, dim=2))
        txt_worker.wait()
        txt_outs.append(torch.cat(txt_out, dim=2))

    # Re-join the head groups along the last dim.
    img_attn_output = torch.cat(img_outs, dim=2)
    txt_attn_output = torch.cat(txt_outs, dim=2)

    img_attn_output = self.to_out[0](img_attn_output)
    img_attn_output = self.to_out[1](img_attn_output)
    txt_attn_output = self.to_add_out(txt_attn_output)
    # (the method's return statement follows, outside this block)
@@ -701,6 +847,11 @@ def _forward(
701847
.reshape(1, -1, 1)
702848
.repeat(x.shape[0], 1, 3)
703849
)
850+
if args.world_size>1:
851+
img_lists = tensor_chunk(hidden_states, -2)
852+
hidden_states = img_lists[args.rank]
853+
sp_len = [img_lists[idx].size(-2) for idx in range(args.world_size)]
854+
transformer_options['sp_len'] = sp_len
704855
ids = torch.cat((txt_ids, img_ids), dim=1)
705856
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
706857
del ids, txt_ids, img_ids
@@ -740,6 +891,7 @@ def block_wrap(args):
740891
encoder_hidden_states_mask=encoder_hidden_states_mask,
741892
temb=args["vec"],
742893
image_rotary_emb=args["pe"],
894+
transformer_options=args["transformer_options"]
743895
)
744896
return out
745897

@@ -756,6 +908,7 @@ def block_wrap(args):
756908
encoder_hidden_states_mask=encoder_hidden_states_mask,
757909
temb=temb,
758910
image_rotary_emb=image_rotary_emb,
911+
transformer_options=transformer_options,
759912
)
760913
# ControlNet helpers(device/dtype-safe residual adds)
761914
_control = (
@@ -790,6 +943,12 @@ def block_wrap(args):
790943
hidden_states = self.norm_out(hidden_states, temb)
791944
hidden_states = self.proj_out(hidden_states)
792945

946+
if args.world_size>1:
947+
bs, _, ndim = hidden_states.shape
948+
# datas = [hidden_states.new_empty((bs, img_lists[rank_i].size(1), ndim)) for rank_i in range(args.world_size)]
949+
datas = [hidden_states.new_empty((bs, sp_len[rank_i], ndim)) for rank_i in range(args.world_size)]
950+
hidden_states = all_gather(datas, hidden_states, -2)
951+
793952
hidden_states = hidden_states[:, :num_embeds].view(
794953
orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2
795954
)

0 commit comments

Comments
 (0)