@@ -613,20 +613,29 @@ def _linear_proj_forward_hook(self, module, input, output):
613613 def _estimate_all_head_importance (self ) -> TracedHp .Importance :
614614 """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
615615 assert self ._activations is not None , "No activations collected for importance estimation."
616- attn_head_importance = self ._activations .view (
617- self .get_hparam ("num_heads_per_group" ).max * self .get_hparam ("num_query_groups" ).max ,
618- self .config .kv_channels ,
619- ).norm (p = 2 , dim = 1 )
616+ attn_head_importance = torch .linalg .vector_norm (
617+ self ._activations .view (
618+ self .get_hparam ("num_heads_per_group" ).max
619+ * self .get_hparam ("num_query_groups" ).max ,
620+ self .config .kv_channels ,
621+ ),
622+ ord = 2 ,
623+ dim = 1 ,
624+ )
620625 return attn_head_importance
621626
622627 def _estimate_query_group_importance (self ) -> TracedHp .Importance :
623628 """Return the importance of the ``num_query_groups`` hparam."""
624629 assert self ._activations is not None , "No activations collected for importance estimation."
625- group_importance = self ._activations .view (
626- self .get_hparam ("num_heads_per_group" ).max ,
627- self .get_hparam ("num_query_groups" ).max ,
628- self .config .kv_channels ,
629- ).norm (p = 2 , dim = (0 , 2 ))
630+ group_importance = torch .linalg .vector_norm (
631+ self ._activations .view (
632+ self .get_hparam ("num_heads_per_group" ).max ,
633+ self .get_hparam ("num_query_groups" ).max ,
634+ self .config .kv_channels ,
635+ ),
636+ ord = 2 ,
637+ dim = (0 , 2 ),
638+ )
630639 return group_importance
631640
632641 def export (self ) -> torch .nn .Module :
0 commit comments