@@ -44,7 +44,7 @@ def _get_rel_pos_bias(self, window_size):
4444 old_sub_table = old_relative_position_bias_table [:old_num_relative_distance - 3 ]
4545
4646 old_sub_table = old_sub_table .reshape (1 , old_width , old_height , - 1 ).permute (0 , 3 , 1 , 2 )
47- new_sub_table = F .interpolate (old_sub_table , size = (new_height , new_width ), mode = "bilinear" )
47+ new_sub_table = F .interpolate (old_sub_table , size = (int ( new_height ), int ( new_width ) ), mode = "bilinear" )
4848 new_sub_table = new_sub_table .permute (0 , 2 , 3 , 1 ).reshape (new_num_relative_distance - 3 , - 1 )
4949
5050 new_relative_position_bias_table = torch .cat (
@@ -96,12 +96,12 @@ def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tenso
9696 Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
9797 """
9898 if self .gamma_1 is None :
99- x = x + self .drop_path (self .attn (self .norm1 (x ), resolution , shared_rel_pos_bias = shared_rel_pos_bias ))
100- x = x + self .drop_path (self .mlp (self .norm2 (x )))
99+ x = x + self .drop_path1 (self .attn (self .norm1 (x ), resolution , shared_rel_pos_bias = shared_rel_pos_bias ))
100+ x = x + self .drop_path2 (self .mlp (self .norm2 (x )))
101101 else :
102- x = x + self .drop_path (self .gamma_1 * self .attn (self .norm1 (x ), resolution ,
102+ x = x + self .drop_path1 (self .gamma_1 * self .attn (self .norm1 (x ), resolution ,
103103 shared_rel_pos_bias = shared_rel_pos_bias ))
104- x = x + self .drop_path (self .gamma_2 * self .mlp (self .norm2 (x )))
104+ x = x + self .drop_path2 (self .gamma_2 * self .mlp (self .norm2 (x )))
105105 return x
106106
107107
0 commit comments