Separate q_proj, k_proj, and v_proj for Attention Layers in SAM #33928

Open · MagnusS0 opened this issue Oct 3, 2024 · 11 comments · May be fixed by #33979
Labels: Feature request · Usage

Comments

MagnusS0 commented Oct 3, 2024

Feature request

I would like to request the addition of separate projection layers (q_proj, k_proj, v_proj) for the attention mechanisms in the SamModel. Currently, it uses a combined qkv projection, which makes it difficult to apply modifications, such as low-rank adaptations (e.g., LoRA), to specific projections like k_proj.

Motivation

The main motivation behind this request is to enable more flexibility in modifying individual components of the attention mechanism, particularly when adding adapters or low-rank transformations. For example, in certain use cases, we may want to apply LoRA only to the key projection (k_proj) without modifying the query or value projections.

Your contribution

Here are the changes I believe need to be made, but please provide some feedback if there is a reason it is done the way it is (speed/computation?).

  • Refactor the existing qkv layer into three separate layers: q_proj, k_proj, and v_proj.
  • Modify the weight initialization to ensure that q_proj, k_proj, and v_proj are properly initialized.
  • Update the weight loading logic to map existing pretrained qkv weights to the new separate projections (see the short sketch after this list).
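
As a rough sketch of what the last bullet could look like, here is one way the mapping might work, assuming the pretrained checkpoint stores the combined projection as qkv.weight / qkv.bias with q, k, and v stacked along dim 0 (an assumption that would need to be verified against the actual SAM checkpoints):

import torch

def split_qkv_state_dict(state_dict):
    """Map combined qkv.* entries to separate q_proj.*, k_proj.*, v_proj.* entries (illustrative sketch)."""
    new_state_dict = {}
    for key, tensor in state_dict.items():
        if ".qkv." in key:
            # Assumes the q, k, and v parameters are concatenated along dim 0, in that order.
            q, k, v = tensor.chunk(3, dim=0)
            new_state_dict[key.replace(".qkv.", ".q_proj.")] = q
            new_state_dict[key.replace(".qkv.", ".k_proj.")] = k
            new_state_dict[key.replace(".qkv.", ".v_proj.")] = v
        else:
            new_state_dict[key] = tensor
    return new_state_dict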

This change would provide more control for researchers and developers like myself who are experimenting with low-rank adaptations in SAM.

Example of a possible solution:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class SamVisionAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings, with separate q, k, v projections and LoRA for k."""

    def __init__(self, config, window_size):
        super().__init__()
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.scale = head_dim ** -0.5
        self.dropout = config.attention_dropout

        # Separate q, k, v projections
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        

        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")

            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
            query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def add_decomposed_rel_pos(
        self,
        attn: torch.Tensor,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = query.shape
        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
        return attn

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape

        # Separate q, k, v projections, each reshaped to (batch_size * num_heads, height * width, head_dim)
        new_shape = (batch_size * self.num_attention_heads, height * width, -1)
        query = (
            self.q_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
            .reshape(new_shape)
        )
        key = (
            self.k_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
            .reshape(new_shape)
        )
        value = (
            self.v_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
            .reshape(new_shape)
        )

        # Continue with attention mechanism as before
        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.o_proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
MagnusS0 added the Feature request label on Oct 3, 2024
ArthurZucker (Collaborator)

Hey! Yeah, this makes sense. The main issue is that the checkpoints we converted and released won't be aligned with this anymore 🥹 I can, however, provide you with a super small script to load one into the other.

We are prioritizing splitting q, k, v in general!

sangam0406 commented Oct 4, 2024

I believe you can modify individual parts of the attention mechanism (by applying fine-tuning techniques to k_proj and keeping the other two unchanged).
We can also adapt specific layers without affecting others.
We can also do mixed-precision training via torch.cuda.amp, which reduces memory usage and increases speed.
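
For illustration, here is a minimal, self-contained mixed-precision training loop with torch.cuda.amp; the tiny linear model, optimizer, and random data are placeholders standing in for the adapted SAM model and a real dataloader, not anything from this thread:

import torch
import torch.nn as nn

# Toy stand-in model and data; in practice this would be the LoRA-adapted SAM model and a real dataloader.
model = nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 1, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # forward pass runs in mixed precision
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)  # unscales gradients, then steps the optimizer
    scaler.update()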

MagnusS0 (Author) commented Oct 4, 2024

@ArthurZucker Thanks for the feedback! I'd love to get my hands on that script. I hope to spend some time this weekend to weed out any other bugs.

ArthurZucker added the Usage label on Oct 4, 2024
ArthurZucker (Collaborator)

Ok, brb with a script!

ArthurZucker (Collaborator) commented Oct 5, 2024

There you go!

from transformers.models.sam.modeling_sam import SamVisionAttention, SamModel, SamVisionLayer
from transformers import SamProcessor
import torch.nn as nn
import torch
from transformers.models.sam import modeling_sam
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
from transformers import pipeline

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        del self.qkv
        # Separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)

    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # Split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # Replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # Mark the old qkv key for deletion
                keys_to_delete.append(key)
        
        # Remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
        query = self.q(hidden_states).reshape(batch_size, height * width, self.num_attention_heads, -1).permute(0, 2, 1, 3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape(batch_size, height * width, self.num_attention_heads, -1).permute(0, 2, 1, 3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape(batch_size, height * width, self.num_attention_heads, -1).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs

modeling_sam.SamVisionAttention = SamVisionAttentionSplit


device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

generator = pipeline("mask-generation", model=model, image_processor=processor.image_processor, device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.savefig("dummy_test")

ArthurZucker (Collaborator)

Produced dummy_test for me!

ArthurZucker (Collaborator)

I think this deserves some documentation, like "How to hack any transformers model" or something like that, where we show how to use transformers for your custom changes! WDYT?

MagnusS0 (Author) commented Oct 5, 2024

Awesome, this is great! I will add this to my repo on fine-tuning SAM and test it out.
https://github.com/MagnusS0/QLoRA-SAM

Love the idea of a "How to hack any transformers" doc, great way to frame it!

Let me know if I can help with any of the documentation and I can write something up.
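
For reference, here is a minimal sketch of how the patched model could then be wrapped with PEFT for LoRA fine-tuning; it assumes the SamVisionAttentionSplit patch above has already been applied (so the split layers are named q, k, v) and is not taken from the QLoRA-SAM repo:

from transformers import SamModel
from peft import LoraConfig, get_peft_model

# Assumes modeling_sam.SamVisionAttention has already been replaced with SamVisionAttentionSplit
# (see the script above), so the loaded vision attention layers expose separate q, k, v projections.
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Attach LoRA adapters only to the split projections; adjust target_modules (e.g. ["k"]) as needed.
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q", "v"], lora_dropout=0.1, bias="none")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights should be trainable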

ArthurZucker (Collaborator)

Yeah for sure, I actually don't use transformers as often as our community 🤣 So let's add a section to the doc. Feel free to open a PR with this example and I'll review! 🤗

MagnusS0 (Author) commented Oct 5, 2024

Great, I just ran my benchmark and everything works perfectly! I do get a warning when initializing the weights, so I’ll add a note about that in the docs. I’ll go ahead and start on the PR 😉

ArthurZucker (Collaborator)

Cool! I'll try to work on the warning as it is not valid! 🤗

MagnusS0 linked a pull request (#33979) on Oct 5, 2024 that will close this issue.