@@ -300,13 +300,15 @@ def apply_merge(self, model, merge_ratio=None, merge_num=None, cand_distribution
300300 merge_ratio = tome .check_parse_r (num_layers , merge_num , self .embed_res ** 2 , inflect )
301301 if isinstance (cand_distribution , str ):
302302 cand_distribution = cand_distribution .split ('-' )
303- if len (cand_distribution ) == 2 :
304- cand_distribution , r_cand_num = cand_distribution [0 ], int (cand_distribution [- 1 ])
303+ if len (cand_distribution ) == 3 :
304+ cand_distribution , r_cand_num , bias = cand_distribution [0 ], int (cand_distribution [ 1 ]), int (cand_distribution [- 1 ])
305+ elif len (cand_distribution ) == 2 :
306+ cand_distribution , r_cand_num , bias = cand_distribution [0 ], int (cand_distribution [- 1 ]), 1
305307 else :
306- cand_distribution , r_cand_num = cand_distribution [0 ], 5
308+ cand_distribution , r_cand_num , bias = cand_distribution [0 ], 5 , 1
307309 if cand_distribution is not None :
308310 # generate candidate list with the center index
309- bias = 1 if cand_distribution .lower () == "gaussian" else 0
311+ bias = bias if cand_distribution .lower () == "gaussian" else 0
310312 remain_list = [int (((self .embed_res ** 2 - merge_num ) ** 0.5 + i ) ** 2 ) \
311313 for i in range (int (- bias ), r_cand_num - bias )]
312314 merged_list = [self .embed_res ** 2 - num for num in remain_list ]
@@ -359,27 +361,51 @@ def load_pretrained_vit(self, model_path=None):
359361 weight_selection = {}
360362 for key in student_weights .keys ():
361363 if ('block' in key or 'cls_token' in key ) and key in teacher_weights .keys ():
362- weight_selection [key ] = uniform_element_selection (teacher_weights [key ], student_weights [key ]. shape )
364+ weight_selection [key ] = uniform_element_selection (teacher_weights [key ], student_weights [key ])
363365 # load to attention
364366 print ("load pre-trained model for encoder:\n " ,
365367 self .attn .load_state_dict (weight_selection , strict = False ))
366368
367369
def _interpolate_1d(x, dim, size, up_mode='nearest'):
    """Resize `x` along axis `dim` to `size` entries using 1-D interpolation.

    Args:
        x: Tensor to resize.
        dim: Axis of `x` to resize.
        size: Target length of that axis.
        up_mode: Mode forwarded to `torch.nn.functional.interpolate`.

    Returns:
        `x` itself when the axis already has `size` entries, otherwise a new
        contiguous tensor whose axis `dim` has length `size`.
    """
    if x.shape[dim] == size:
        return x
    # `interpolate` resizes the trailing axis of an (N, C, L) input, so swap
    # `dim` with the last axis and fold every other axis into the batch.
    x = x.transpose(dim, -1).contiguous()
    lead_shape = x.shape[:-1]
    x = x.reshape(-1, 1, x.shape[-1])
    if up_mode == 'nearest':
        # 'nearest' does not accept the align_corners argument.
        x = torch.nn.functional.interpolate(x, size=size, mode=up_mode)
    else:
        x = torch.nn.functional.interpolate(x, size=size, mode=up_mode, align_corners=False)
    # Swapping two axes is its own inverse, so the same transpose restores
    # the original axis order.
    return x.reshape(*lead_shape, size).transpose(dim, -1).contiguous()


def uniform_element_selection(tea_weights, stu_weights):
    """Resize teacher weights to the student's shape, one axis at a time.

    Modified and borrowed from `Large Model Initialization`
    (https://arxiv.org/abs/2311.18823).

    For every axis where the teacher is at least as large as the student,
    evenly spaced teacher entries are selected. For every axis where the
    teacher is smaller, the teacher is upsampled with nearest-neighbour
    interpolation.

    Args:
        tea_weights: Teacher tensor. It is never modified.
        stu_weights: Student tensor, or just its shape (`torch.Size`, tuple or
            list of ints). Only the shape is used.

    Returns:
        A new tensor with the student's shape, on the teacher's device.

    Raises:
        AssertionError: If teacher and student differ in number of dimensions.
    """
    # Accept a bare shape as well as a tensor so that callers written against
    # the earlier `(tea_weights, stu_shape)` form keep working.
    if hasattr(stu_weights, "shape"):
        target_shape = tuple(stu_weights.shape)
    else:
        target_shape = tuple(stu_weights)

    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if tea_weights.dim() != len(target_shape):
        raise AssertionError("Tensors have different number of dimensions")

    selected = tea_weights.clone()
    if tuple(selected.shape) == target_shape:
        return selected

    device = selected.device
    for dim, dst_size in enumerate(target_shape):
        src_size = selected.shape[dim]
        if src_size < dst_size:
            # Teacher's dimension < student's dimension: upsample.
            selected = _interpolate_1d(selected, dim, dst_size)
            continue
        # Teacher's dimension >= student's dimension: pick evenly spaced entries.
        if src_size % dst_size == 0:
            indices = torch.arange(dst_size, device=device) * (src_size // dst_size)
        else:
            indices = torch.round(
                torch.linspace(0, src_size - 1, dst_size, device=device)
            ).long()
        # The index tensor must live on the same device as the weights.
        selected = torch.index_select(selected, dim, indices)
    return selected
384410
385411
@@ -732,7 +758,7 @@ def forward(self, x, quantizer):
732758 from fvcore .nn import FlopCountAnalysis , flop_count_table
733759 resolution = 256
734760 cand_distribution = 'gaussian-6'
735- cand_sample_times = 0
761+ cand_sample_times = 10
736762
737763 # ch, ch_mult = 64, (1, 2, 4, 8)
738764 # num_att_blocks, num_res_blocks, r, merge_num = 12, 4, None, 768
@@ -773,6 +799,7 @@ def forward(self, x, quantizer):
773799 print ('encoder (r={}): {}' .format (model .attn .r , y .shape ))
774800 print (flop_count_table (flop , max_depth = 4 ))
775801 print ('MACs (G) of Encoder: {:.3f}' .format (flop .total () / 1e9 ))
802+ # print(model)
776803
777804 if source is not None :
778805 print ('encoder source matrix:' , source .shape )
0 commit comments