1414from __future__ import annotations
1515
1616import math
17- from typing import Any , Callable
17+ from typing import Any , Callable , Optional , Tuple , Union
1818
1919import torch
2020from ConfigSpace import UniformIntegerHyperparameter
@@ -230,8 +230,9 @@ class ToMeViTSelfAttention(_HFViTSelfAttention):
230230 def forward (
231231 self ,
232232 hidden_states : torch .Tensor ,
233- head_mask : torch .Tensor | None = None ,
234- ) -> tuple [torch .Tensor , torch .Tensor ]:
233+ head_mask : Optional [torch .Tensor ] = None ,
234+ output_attentions : bool = False ,
235+ ) -> Union [Tuple [torch .Tensor , torch .Tensor ], Tuple [torch .Tensor ]]:
235236 """Forward pass with proportional attention and key-metric storage."""
236237 batch_size = hidden_states .shape [0 ]
237238 new_shape = (batch_size , - 1 , self .num_attention_heads , self .attention_head_size )
@@ -256,10 +257,12 @@ def forward(
256257 context_layer = (attn_probs @ value_layer ).transpose (1 , 2 )
257258 context_layer = context_layer .reshape (batch_size , - 1 , self .all_head_size )
258259
260+ outputs = (context_layer , attn_probs ) if output_attentions else (context_layer ,)
261+
259262 # Store the key mean as the similarity metric for token merging.
260263 self ._tome_info ["metric" ] = key_layer .mean (1 )
261264
262- return context_layer , attn_weights
265+ return outputs
263266
264267 class ToMeViTLayer (_HFViTLayer ):
265268 """
@@ -276,12 +279,19 @@ class ToMeViTLayer(_HFViTLayer):
276279 def forward (
277280 self ,
278281 hidden_states : torch .Tensor ,
279- head_mask : torch .Tensor | None = None ,
280- ) -> torch .Tensor :
282+ head_mask : Optional [torch .Tensor ] = None ,
283+ output_attentions : bool = False ,
284+ ) -> Union [Tuple [torch .Tensor , torch .Tensor ], Tuple [torch .Tensor ]]:
281285 """Forward pass with token merging between attention and MLP."""
282286 # --- self-attention + first residual ---
283- hidden_states_norm = self .layernorm_before (hidden_states )
284- attention_output = self .attention (hidden_states_norm , head_mask )
287+ self_attention_outputs = self .attention (
288+ self .layernorm_before (hidden_states ),
289+ head_mask ,
290+ output_attentions = output_attentions ,
291+ )
292+ attention_output = self_attention_outputs [0 ]
293+ outputs = self_attention_outputs [1 :] # add self attentions if we output attention weights
294+
285295 hidden_states = attention_output + hidden_states
286296
287297 # --- token merging ---
@@ -303,7 +313,9 @@ def forward(
303313 layer_output = self .intermediate (layer_output )
304314 layer_output = self .output (layer_output , hidden_states )
305315
306- return layer_output
316+ outputs = (layer_output ,) + outputs
317+
318+ return outputs
307319
308320except ImportError :
309321 ToMeViTSelfAttention = None