Skip to content

Commit a6a79ec

Browse files
committed
fix: adapt tome function signatures to HF classes
1 parent 0a68987 commit a6a79ec

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

src/pruna/algorithms/token_merging.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import math
17-
from typing import Any, Callable
17+
from typing import Any, Callable, Optional, Tuple, Union
1818

1919
import torch
2020
from ConfigSpace import UniformIntegerHyperparameter
@@ -230,8 +230,9 @@ class ToMeViTSelfAttention(_HFViTSelfAttention):
230230
def forward(
231231
self,
232232
hidden_states: torch.Tensor,
233-
head_mask: torch.Tensor | None = None,
234-
) -> tuple[torch.Tensor, torch.Tensor]:
233+
head_mask: Optional[torch.Tensor] = None,
234+
output_attentions: bool = False,
235+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
235236
"""Forward pass with proportional attention and key-metric storage."""
236237
batch_size = hidden_states.shape[0]
237238
new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
@@ -256,10 +257,12 @@ def forward(
256257
context_layer = (attn_probs @ value_layer).transpose(1, 2)
257258
context_layer = context_layer.reshape(batch_size, -1, self.all_head_size)
258259

260+
outputs = (context_layer, attn_probs) if output_attentions else (context_layer,)
261+
259262
# Store the key mean as the similarity metric for token merging.
260263
self._tome_info["metric"] = key_layer.mean(1)
261264

262-
return context_layer, attn_weights
265+
return outputs
263266

264267
class ToMeViTLayer(_HFViTLayer):
265268
"""
@@ -276,12 +279,19 @@ class ToMeViTLayer(_HFViTLayer):
276279
def forward(
277280
self,
278281
hidden_states: torch.Tensor,
279-
head_mask: torch.Tensor | None = None,
280-
) -> torch.Tensor:
282+
head_mask: Optional[torch.Tensor] = None,
283+
output_attentions: bool = False,
284+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
281285
"""Forward pass with token merging between attention and MLP."""
282286
# --- self-attention + first residual ---
283-
hidden_states_norm = self.layernorm_before(hidden_states)
284-
attention_output = self.attention(hidden_states_norm, head_mask)
287+
self_attention_outputs = self.attention(
288+
self.layernorm_before(hidden_states),
289+
head_mask,
290+
output_attentions=output_attentions,
291+
)
292+
attention_output = self_attention_outputs[0]
293+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
294+
285295
hidden_states = attention_output + hidden_states
286296

287297
# --- token merging ---
@@ -303,7 +313,9 @@ def forward(
303313
layer_output = self.intermediate(layer_output)
304314
layer_output = self.output(layer_output, hidden_states)
305315

306-
return layer_output
316+
outputs = (layer_output,) + outputs
317+
318+
return outputs
307319

308320
except ImportError:
309321
ToMeViTSelfAttention = None

0 commit comments

Comments (0)