Separate q_proj, k_proj, and v_proj for Attention Layers in SAM #33928
Comments
Hey! Yeah, this makes sense. The main issue is that the checkpoints we converted and released won't be aligned with this anymore 🥹 I can however provide you with a super small script to load one into the other? We are prioritizing split q/k/v in general!
I believe you can modify individual parts of the attention mechanism (e.g., by applying fine-tuning techniques to `k_proj` and keeping the other two unchanged).
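As a minimal sketch of that idea (assuming the split attention class shown later in this thread, where each vision attention layer exposes a separate `k` linear module, so parameter names contain `.attn.k.`):

```python
# Train only the key projections; freeze q, v, and everything else.
# The ".attn.k." substring assumes the split module naming used below.
for name, param in model.named_parameters():
    param.requires_grad = ".attn.k." in name
```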
@ArthurZucker Thanks for the feedback! I'd love to get my hands on that script. I hope to spend some time this weekend to weed out any other bugs.
Ok, brb with a script!
There you go!

```python
from transformers.models.sam.modeling_sam import SamVisionAttention, SamModel, SamVisionLayer
from transformers import SamProcessor
import torch.nn as nn
import torch
from transformers.models.sam import modeling_sam
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
from transformers import pipeline


class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        del self.qkv
        # Separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)

    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # Split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # Replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # Mark the old qkv key for deletion
                keys_to_delete.append(key)
        # Remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
        query = self.q(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs


# Patch the module-level class before instantiating the model so every
# vision layer picks up the split attention implementation.
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
generator = pipeline("mask-generation", device=device, points_per_batch=256, model=model, image_processor=processor.image_processor)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.savefig("dummy_test")
```
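Once the projections are split as above, the LoRA use case from the feature request becomes straightforward. A sketch, assuming the PEFT library is available (PEFT is not mentioned in this thread; `"k"` matches the name of the split key projection module):

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["k"],  # adapt only the key projection of each attention layer
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```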
I think this deserves some documentation, like "How to hack any transformers model" or something like that, where we show how to use transformers for your custom changes! WDYT?
Awesome, this is great! I will add this to my repo on fine-tuning SAM and test this out. Love the idea of a "How to hack any transformers" doc, great way to frame it! Let me know if I can help with any of the documentation and I can write something up.
Yeah for sure, I actually don't use
Great, I just ran my benchmark and everything works perfectly! I do get a warning when initializing the weights, so I'll add a note about that in the docs. I'll go ahead and start on the PR 😉
Cool! I'll try to work on the warning as it is not valid! 🤗
Feature request
I would like to request the addition of separate projection layers (`q_proj`, `k_proj`, `v_proj`) for the attention mechanisms in the `SamModel`. Currently it uses a combined `qkv` projection, which makes it difficult to apply modifications, such as low-rank adaptations (e.g., LoRA), to specific projections like `k_proj`.
Motivation
The main motivation behind this request is to enable more flexibility in modifying individual components of the attention mechanism, particularly when adding adapters or low-rank transformations. For example, in certain use cases we may want to apply LoRA only to the key projection (`k_proj`) without modifying the query or value projections.
Your contribution
Here are the changes I believe need to be made, but please provide some feedback if there is a reason it is done the way it is (speed/computation?): split the `qkv` layer into three separate layers: `q_proj`, `k_proj`, and `v_proj`. This change would provide more control for researchers and developers like myself who are experimenting with low-rank adaptations in SAM.
Example possible solution:
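For illustration, a minimal sketch of what the split might look like (hypothetical; the real change would live in `modeling_sam.py` and would also need to remap existing fused `qkv` checkpoint weights, as the script above does):

```python
import torch.nn as nn

class SamVisionAttention(nn.Module):
    def __init__(self, config, window_size):
        super().__init__()
        ...
        # Instead of a single fused projection, e.g.
        #   self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
        # use three separate projections:
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
```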