Separate q_proj, k_proj, and v_proj for Attention Layers in SAM #33928

Closed
@MagnusS0

Description

Feature request

I would like to request the addition of separate projection layers (q_proj, k_proj, v_proj) for the attention mechanisms in the SamModel. Currently, it uses a combined qkv projection, which makes it difficult to apply modifications, such as low-rank adaptations (e.g., LoRA), to specific projections like k_proj.

Motivation

The main motivation behind this request is to enable more flexibility in modifying individual components of the attention mechanism, particularly when adding adapters or low-rank transformations. For example, in certain use cases, we may want to apply LoRA only to the key projection (k_proj) without modifying the query or value projections.
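
For concreteness, here is a minimal sketch of the intended downstream use, assuming this refactor is in place and the peft library is installed (the checkpoint name and the rank/alpha values are just illustrative, and depending on the final module names the pattern may need to be scoped to the vision encoder):

from transformers import SamModel
from peft import LoraConfig, get_peft_model

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Restrict LoRA to the key projection of the vision encoder's attention blocks.
# This only becomes possible once k_proj exists as its own nn.Linear; today the
# vision encoder exposes a single fused qkv layer.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["k_proj"],
    lora_dropout=0.0,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()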

Your contribution

Here are the changes I believe need to be made, but please provide feedback if there is a reason it is currently done this way (e.g., speed or compute efficiency).

  • Refactor the existing qkv layer into three separate layers: q_proj, k_proj, and v_proj.
  • Modify the weight initialization to ensure that q_proj, k_proj, and v_proj are properly initialized.
  • Update the weight loading logic to map existing pretrained qkv weights to the new separate projections (a sketch of this mapping follows the list).
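
For the last point, a rough sketch of how the fused weights could be split, assuming the fused layer is named qkv, is an nn.Linear(hidden_size, 3 * hidden_size), and stores the projections in q, k, v order; the helper name and the exact state-dict keys are illustrative:

import torch


def split_qkv_state_dict(state_dict):
    """Hypothetical helper: map fused `qkv` weights/biases to separate q/k/v projections."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith("qkv.weight") or key.endswith("qkv.bias"):
            # weight: (3 * hidden_size, hidden_size), bias: (3 * hidden_size,)
            q, k, v = torch.chunk(value, 3, dim=0)
            new_state_dict[key.replace("qkv", "q_proj")] = q
            new_state_dict[key.replace("qkv", "k_proj")] = k
            new_state_dict[key.replace("qkv", "v_proj")] = v
        else:
            new_state_dict[key] = value
    return new_state_dict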

This change would provide more control for researchers and developers like myself who are experimenting with low-rank adaptations in SAM.

Example of a possible solution:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class SamVisionAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings, using separate q, k, v projections."""

    def __init__(self, config, window_size):
        super().__init__()
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.scale = head_dim ** -0.5
        self.dropout = config.attention_dropout

        # Separate q, k, v projections
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)

        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")

            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
            query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def add_decomposed_rel_pos(
        self,
        attn: torch.Tensor,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = query.shape
        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
        return attn

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape

        # Separate q, k, v projections
        query = (
            self.q_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
        )
        key = (
            self.k_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
        )

        value = (
            self.v_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
        )

        # Continue with attention mechanism as before
        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.o_proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
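
As a sanity check on the weight mapping, the separate projections should reproduce the fused qkv output exactly; something along these lines (module names and sizes are illustrative) could go into the tests:

import torch
import torch.nn as nn

hidden_size = 768
qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
q_proj = nn.Linear(hidden_size, hidden_size, bias=True)
k_proj = nn.Linear(hidden_size, hidden_size, bias=True)
v_proj = nn.Linear(hidden_size, hidden_size, bias=True)

# Copy the fused weights into the separate projections (q, k, v order assumed).
with torch.no_grad():
    q_w, k_w, v_w = torch.chunk(qkv.weight, 3, dim=0)
    q_b, k_b, v_b = torch.chunk(qkv.bias, 3, dim=0)
    q_proj.weight.copy_(q_w)
    q_proj.bias.copy_(q_b)
    k_proj.weight.copy_(k_w)
    k_proj.bias.copy_(k_b)
    v_proj.weight.copy_(v_w)
    v_proj.bias.copy_(v_b)

hidden_states = torch.randn(2, 64 * 64, hidden_size)
fused_q, fused_k, fused_v = qkv(hidden_states).chunk(3, dim=-1)
assert torch.allclose(fused_q, q_proj(hidden_states))
assert torch.allclose(fused_k, k_proj(hidden_states))
assert torch.allclose(fused_v, v_proj(hidden_states))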

    Labels

    Feature request · Usage
