2525from nunchaku .models .linear import AWQW4A16Linear , SVDQW4A4Linear
2626from nunchaku .models .utils import CPUOffloadManager
2727from nunchaku .ops .fused import fused_gelu_mlp
28+ from dist_utils import all_all_async , args , tensor_chunk , all_gather , all_all , has_nvlink , all_gather_async
29+ import logging
2830
2931from ..mixins .model import NunchakuModelMixin
3032
@@ -238,6 +240,12 @@ def __init__(
238240 self .to_add_out = SVDQW4A4Linear (
239241 self .inner_dim , self .out_context_dim , bias = out_bias , torch_dtype = dtype , device = device , ** kwargs
240242 )
243+ if args .world_size > 1 and (not has_nvlink ):
244+ self .overlap_num = self .heads // (2 * args .world_size )
245+ if self .overlap_num > 1 :
246+ if args .rank == 0 :
247+ logging .info (f"no nvlink and self.overlap_num={ self .overlap_num } , using compute and communication overlap" )
248+ self .forward = self .forward_overlap
241249
242250 def forward (
243251 self ,
@@ -246,6 +254,7 @@ def forward(
246254 encoder_hidden_states_mask : torch .FloatTensor = None ,
247255 attention_mask : Optional [torch .FloatTensor ] = None ,
248256 image_rotary_emb : Optional [torch .Tensor ] = None ,
257+ transformer_options = {},
249258 ) -> Tuple [torch .Tensor , torch .Tensor ]:
250259 """
251260 Forward pass for double-stream attention.
@@ -274,46 +283,183 @@ def forward(
274283
275284 img_qkv = self .to_qkv (hidden_states )
276285 img_query , img_key , img_value = img_qkv .chunk (3 , dim = - 1 )
286+ img_query = img_query .unflatten (- 1 , (self .heads , - 1 ))
287+ img_key = img_key .unflatten (- 1 , (self .heads , - 1 ))
288+ img_value = img_value .unflatten (- 1 , (self .heads , - 1 ))
289+ b , _ , _ , d = img_value .shape
290+ if args .world_size > 1 :
291+ sp_lens = transformer_options .get ('sp_len' )
292+ v_data_list = tensor_chunk (img_value , - 2 )
293+ val_datashapes = [[b , sp_lens [rank_i ], v_data_list [rank_i ].size (- 2 ), d ] for rank_i in range (args .world_size )]
294+ output_val_list , val_async_worker , val_datashapes = all_all_async (img_value , - 2 , val_datashapes , v_data_list )
295+ img_query = self .norm_q (img_query )
296+ img_key = self .norm_k (img_key )
297+ output_q_list , q_async_worker , _ = all_all_async (img_query , - 2 , val_datashapes )
298+ output_k_list , k_async_worker , _ = all_all_async (img_key , - 2 , val_datashapes )
299+ else :
300+ img_query = self .norm_q (img_query )
301+ img_key = self .norm_k (img_key )
277302
278303 # Compute QKV for text stream (context projections)
279304 txt_qkv = self .add_qkv_proj (encoder_hidden_states )
280305 txt_query , txt_key , txt_value = txt_qkv .chunk (3 , dim = - 1 )
281306
282- img_query = img_query .unflatten (- 1 , (self .heads , - 1 ))
283- img_key = img_key .unflatten (- 1 , (self .heads , - 1 ))
284- img_value = img_value .unflatten (- 1 , (self .heads , - 1 ))
285-
286307 txt_query = txt_query .unflatten (- 1 , (self .heads , - 1 ))
287308 txt_key = txt_key .unflatten (- 1 , (self .heads , - 1 ))
288309 txt_value = txt_value .unflatten (- 1 , (self .heads , - 1 ))
289310
290- img_query = self .norm_q (img_query )
291- img_key = self .norm_k (img_key )
292311 txt_query = self .norm_added_q (txt_query )
293312 txt_key = self .norm_added_k (txt_key )
294313
295- # Concatenate image and text streams for joint attention
296- joint_query = torch .cat ([txt_query , img_query ], dim = 1 )
297- joint_key = torch .cat ([txt_key , img_key ], dim = 1 )
298- joint_value = torch .cat ([txt_value , img_value ], dim = 1 )
314+ if args .world_size > 1 :
315+ txt_data_list = tensor_chunk (txt_value , 2 )
316+ txt_value = txt_data_list [args .rank ]
317+ txt_query = tensor_chunk (txt_query , 2 )[args .rank ]
318+ txt_key = tensor_chunk (txt_key , 2 )[args .rank ]
319+
320+ val_async_worker .wait ()
321+ img_value = torch .cat (output_val_list , dim = - 3 ).contiguous ()
322+ joint_value = torch .cat ([txt_value , img_value ], dim = 1 )
323+ joint_value = joint_value .flatten (start_dim = 2 )
324+
325+ q_async_worker .wait ()
326+ img_query = torch .cat (output_q_list , dim = - 3 ).contiguous ()
327+ joint_query = torch .cat ([txt_query , img_query ], dim = 1 )
328+ joint_query = apply_rotary_emb (joint_query , image_rotary_emb )
329+ heads = joint_query .size (- 2 )
330+ joint_query = joint_query .flatten (start_dim = 2 )
331+
332+ k_async_worker .wait ()
333+ img_key = torch .cat (output_k_list , dim = - 3 ).contiguous ()
334+ joint_key = torch .cat ([txt_key , img_key ], dim = 1 )
335+ joint_key = apply_rotary_emb (joint_key , image_rotary_emb )
336+ joint_key = joint_key .flatten (start_dim = 2 )
337+
338+ else :
339+ joint_value = torch .cat ([txt_value , img_value ], dim = 1 )
340+ joint_value = joint_value .flatten (start_dim = 2 )
299341
300- # Apply rotary embeddings
301- joint_query = apply_rotary_emb (joint_query , image_rotary_emb )
302- joint_key = apply_rotary_emb (joint_key , image_rotary_emb )
342+ joint_query = torch .cat ([txt_query , img_query ], dim = 1 )
343+ joint_query = apply_rotary_emb (joint_query , image_rotary_emb )
344+ heads = joint_query .size (- 2 )
345+ joint_query = joint_query .flatten (start_dim = 2 )
303346
304- joint_query = joint_query . flatten ( start_dim = 2 )
305- joint_key = joint_key . flatten ( start_dim = 2 )
306- joint_value = joint_value .flatten (start_dim = 2 )
347+ joint_key = torch . cat ([ txt_key , img_key ], dim = 1 )
348+ joint_key = apply_rotary_emb ( joint_key , image_rotary_emb )
349+ joint_key = joint_key .flatten (start_dim = 2 )
307350
308351 # Compute joint attention
309352 joint_hidden_states = optimized_attention_masked (
310- joint_query , joint_key , joint_value , self . heads , attention_mask
353+ joint_query , joint_key , joint_value , heads , attention_mask
311354 )
312355
313356 # Split results back to separate streams
314357 txt_attn_output = joint_hidden_states [:, :seq_txt , :]
315358 img_attn_output = joint_hidden_states [:, seq_txt :, :]
316359
360+ if args .world_size > 1 :
361+ data_shapes = [[b , sp_lens [args .rank ], val_datashapes [rank_i ][2 ] * val_datashapes [rank_i ][3 ]] for rank_i in range (args .world_size )]
362+ img_attn_output = all_all (img_attn_output , - 2 , - 1 , data_shapes , sp_lens )
363+ txt_attn_output = all_gather ([_ .flatten (start_dim = 2 ) for _ in txt_data_list ], txt_attn_output , 2 )
364+
365+ img_attn_output = self .to_out [0 ](img_attn_output )
366+ img_attn_output = self .to_out [1 ](img_attn_output )
367+ txt_attn_output = self .to_add_out (txt_attn_output )
368+
369+ return img_attn_output , txt_attn_output
370+
371+ def forward_overlap (
372+ self ,
373+ hidden_states : torch .FloatTensor , # Image stream
374+ encoder_hidden_states : torch .FloatTensor = None , # Text stream
375+ encoder_hidden_states_mask : torch .FloatTensor = None ,
376+ attention_mask : Optional [torch .FloatTensor ] = None ,
377+ image_rotary_emb : Optional [torch .Tensor ] = None ,
378+ transformer_options = {},
379+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
380+ seq_txt = encoder_hidden_states .shape [1 ]
381+
382+ img_qkv = self .to_qkv (hidden_states )
383+ img_query , img_key , img_value = img_qkv .chunk (3 , dim = - 1 )
384+ img_query = img_query .unflatten (- 1 , (self .heads , - 1 ))
385+ img_key = img_key .unflatten (- 1 , (self .heads , - 1 ))
386+ img_value = img_value .unflatten (- 1 , (self .heads , - 1 ))
387+ img_query = self .norm_q (img_query )
388+ img_key = self .norm_k (img_key )
389+ b , _ , _ , d = img_value .shape
390+
391+ sp_lens = transformer_options .get ('sp_len' )
392+
393+ img_qkv = torch .cat ([img_query , img_key , img_value ], dim = 0 )
394+ img_qkv_lists = img_qkv .chunk (self .overlap_num , 2 )
395+ output_qkv_workers = []
396+ for qkv_data in img_qkv_lists :
397+ qkv_data_list = tensor_chunk (qkv_data , - 2 )
398+ qkv_datashapes = [[3 * b , sp_lens [rank_i ], qkv_data_list [rank_i ].size (- 2 ), d ] for rank_i in range (args .world_size )]
399+ output_qkv_list , qkv_async_worker , _ = all_all_async (img_qkv , - 2 , qkv_datashapes , qkv_data_list )
400+ output_qkv_workers .append ([output_qkv_list , qkv_async_worker ])
401+
402+ txt_qkv = self .add_qkv_proj (encoder_hidden_states )
403+ txt_query , txt_key , txt_value = txt_qkv .chunk (3 , dim = - 1 )
404+
405+ txt_query = txt_query .unflatten (- 1 , (self .heads , - 1 ))
406+ txt_key = txt_key .unflatten (- 1 , (self .heads , - 1 ))
407+ txt_value = txt_value .unflatten (- 1 , (self .heads , - 1 ))
408+
409+ txt_query = self .norm_added_q (txt_query )
410+ txt_key = self .norm_added_k (txt_key )
411+
412+ txt_key_lists = txt_key .chunk (self .overlap_num , 2 )
413+ txt_query_lists = txt_query .chunk (self .overlap_num , 2 )
414+ txt_value_lists = txt_value .chunk (self .overlap_num , 2 )
415+ img_attn_output_works_lists = []
416+ txt_attn_output_works_lists = []
417+
418+ for data_idx , output_qkv_worker in enumerate (output_qkv_workers ):
419+ qkv_list , qkv_worker = output_qkv_worker
420+
421+ txt_key_list = tensor_chunk (txt_key_lists [data_idx ], 2 )
422+ _txt_key = txt_key_list [args .rank ]
423+ _txt_query = tensor_chunk (txt_query_lists [data_idx ], 2 )[args .rank ]
424+ _txt_value = tensor_chunk (txt_value_lists [data_idx ], 2 )[args .rank ]
425+
426+ qkv_worker .wait ()
427+ q , k , v = torch .cat (qkv_list , dim = 1 ).chunk (3 , 0 )
428+
429+ joint_value = torch .cat ([_txt_value , v ], dim = 1 )
430+ joint_value = joint_value .flatten (start_dim = 2 )
431+
432+ joint_query = torch .cat ([_txt_query , q ], dim = 1 )
433+ joint_query = apply_rotary_emb (joint_query , image_rotary_emb )
434+ heads = joint_query .size (- 2 )
435+ joint_query = joint_query .flatten (start_dim = 2 )
436+
437+ joint_key = torch .cat ([_txt_key , k ], dim = 1 )
438+ joint_key = apply_rotary_emb (joint_key , image_rotary_emb )
439+ joint_key = joint_key .flatten (start_dim = 2 )
440+
441+ joint_hidden_states = optimized_attention_masked (joint_query , joint_key , joint_value , heads , attention_mask , transformer_options = transformer_options )
442+ txt_attn_output = joint_hidden_states [:, :seq_txt , :]
443+ img_attn_output = joint_hidden_states [:, seq_txt :, :]
444+ data_shapes = [[b , sp_lens [args .rank ], txt_key_list [rank_i ].size (2 ) * d ] for rank_i in range (args .world_size )]
445+
446+ img_attn_output_list , img_attn_output_worker , _ = all_all_async (img_attn_output , 1 , data_shapes , tensor_chunk (img_attn_output , 1 ))
447+ img_attn_output_works_lists .append ([img_attn_output_list , img_attn_output_worker ])
448+ txt_attn_output_list , txt_attn_output_worker = all_gather_async ([_ .flatten (start_dim = 2 ) for _ in txt_key_list ], txt_attn_output , 2 )
449+ txt_attn_output_works_lists .append ([txt_attn_output_list , txt_attn_output_worker ])
450+
451+ img_outs = []
452+ txt_outs = []
453+ for img_idx , (img_out , img_worker ) in enumerate (img_attn_output_works_lists ):
454+ img_worker .wait ()
455+ img_outs .append (torch .cat (img_out , dim = 2 ))
456+ txt_out , txt_worker = txt_attn_output_works_lists [img_idx ]
457+ txt_worker .wait ()
458+ txt_outs .append (torch .cat (txt_out , dim = 2 ))
459+
460+ img_attn_output = torch .cat (img_outs , dim = 2 )
461+ txt_attn_output = torch .cat (txt_outs , dim = 2 )
462+
317463 img_attn_output = self .to_out [0 ](img_attn_output )
318464 img_attn_output = self .to_out [1 ](img_attn_output )
319465 txt_attn_output = self .to_add_out (txt_attn_output )
@@ -701,6 +847,11 @@ def _forward(
701847 .reshape (1 , - 1 , 1 )
702848 .repeat (x .shape [0 ], 1 , 3 )
703849 )
850+ if args .world_size > 1 :
851+ img_lists = tensor_chunk (hidden_states , - 2 )
852+ hidden_states = img_lists [args .rank ]
853+ sp_len = [img_lists [idx ].size (- 2 ) for idx in range (args .world_size )]
854+ transformer_options ['sp_len' ] = sp_len
704855 ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
705856 image_rotary_emb = self .pe_embedder (ids ).squeeze (1 ).unsqueeze (2 ).to (x .dtype )
706857 del ids , txt_ids , img_ids
@@ -740,6 +891,7 @@ def block_wrap(args):
740891 encoder_hidden_states_mask = encoder_hidden_states_mask ,
741892 temb = args ["vec" ],
742893 image_rotary_emb = args ["pe" ],
894+ transformer_options = args ["transformer_options" ]
743895 )
744896 return out
745897
@@ -756,6 +908,7 @@ def block_wrap(args):
756908 encoder_hidden_states_mask = encoder_hidden_states_mask ,
757909 temb = temb ,
758910 image_rotary_emb = image_rotary_emb ,
911+ transformer_options = transformer_options ,
759912 )
760913 # ControlNet helpers(device/dtype-safe residual adds)
761914 _control = (
@@ -790,6 +943,12 @@ def block_wrap(args):
790943 hidden_states = self .norm_out (hidden_states , temb )
791944 hidden_states = self .proj_out (hidden_states )
792945
946+ if args .world_size > 1 :
947+ bs , _ , ndim = hidden_states .shape
948+ # datas = [hidden_states.new_empty((bs, img_lists[rank_i].size(1), ndim)) for rank_i in range(args.world_size)]
949+ datas = [hidden_states .new_empty ((bs , sp_len [rank_i ], ndim )) for rank_i in range (args .world_size )]
950+ hidden_states = all_gather (datas , hidden_states , - 2 )
951+
793952 hidden_states = hidden_states [:, :num_embeds ].view (
794953 orig_shape [0 ], orig_shape [- 2 ] // 2 , orig_shape [- 1 ] // 2 , orig_shape [1 ], 2 , 2
795954 )
0 commit comments