Description
Feature request
I would like to request the addition of separate projection layers (`q_proj`, `k_proj`, `v_proj`) for the attention mechanisms in the `SamModel`. Currently, it uses a combined `qkv` projection, which makes it difficult to apply modifications, such as low-rank adaptations (e.g., LoRA), to specific projections like `k_proj`.
Motivation
The main motivation behind this request is to enable more flexibility in modifying individual components of the attention mechanism, particularly when adding adapters or low-rank transformations. For example, in certain use cases, we may want to apply LoRA only to the key projection (`k_proj`) without modifying the query or value projections.
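As an illustration, once the projections are exposed as separate `nn.Linear` layers, LoRA could be restricted to the key projection with the `peft` library. This is a minimal sketch, assuming the refactored layers are named `k_proj` as proposed below; today the combined `qkv` layer can only be adapted as a whole:

from transformers import SamModel
from peft import LoraConfig, get_peft_model

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Inject LoRA only into the (proposed) key projection, leaving the query and
# value projections frozen and unchanged. Note that the attention layers
# elsewhere in SamModel already expose k_proj, so in practice the match may
# need to be restricted to the vision encoder (e.g. via a regex).
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["k_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()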
Your contribution
Here are the changes I believe are needed, but please provide some feedback if there is a reason it is done the way it is (speed/computation?):
- Refactor the existing `qkv` layer into three separate layers: `q_proj`, `k_proj`, and `v_proj`.
- Modify the weight initialization to ensure that `q_proj`, `k_proj`, and `v_proj` are properly initialized.
- Update the weight loading logic to map existing pretrained `qkv` weights to the new separate projections (a sketch of this mapping follows below).
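A minimal sketch of that weight mapping, assuming the pretrained combined weight is stacked as [q; k; v] along the output dimension (which is how the current `qkv` output is consumed in the forward pass); the helper name `split_qkv_weights` is hypothetical:

import torch
from torch import nn


def split_qkv_weights(qkv: nn.Linear, hidden_size: int):
    """Copy a combined qkv projection into three separate Linear layers."""
    has_bias = qkv.bias is not None
    q_proj = nn.Linear(hidden_size, hidden_size, bias=has_bias)
    k_proj = nn.Linear(hidden_size, hidden_size, bias=has_bias)
    v_proj = nn.Linear(hidden_size, hidden_size, bias=has_bias)

    # The combined weight has shape (3 * hidden_size, hidden_size); chunking
    # along dim 0 recovers the q, k and v blocks in that order.
    q_w, k_w, v_w = qkv.weight.data.chunk(3, dim=0)
    q_proj.weight.data.copy_(q_w)
    k_proj.weight.data.copy_(k_w)
    v_proj.weight.data.copy_(v_w)

    if has_bias:
        q_b, k_b, v_b = qkv.bias.data.chunk(3, dim=0)
        q_proj.bias.data.copy_(q_b)
        k_proj.bias.data.copy_(k_b)
        v_proj.bias.data.copy_(v_b)

    return q_proj, k_proj, v_proj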
This change would provide more control for researchers and developers like myself who are experimenting with low-rank adaptations in SAM.
Example of a possible solution:
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


class SamVisionAttention(nn.Module):
    """Multi-head attention block with relative position embeddings and separate q, k, v projections (so that adapters such as LoRA can target individual projections, e.g. k_proj)."""

    def __init__(self, config, window_size):
        super().__init__()
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.scale = head_dim ** -0.5
        self.dropout = config.attention_dropout

        # Separate q, k, v projections instead of the combined qkv layer
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")

            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]
    def add_decomposed_rel_pos(
        self,
        attn: torch.Tensor,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = query.shape
        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
        return attn
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape

        # Separate q, k, v projections
        query = (
            self.q_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
        )
        key = (
            self.k_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.v_proj(hidden_states)
            .reshape(batch_size, height * width, self.num_attention_heads, -1)
            .permute(0, 2, 1, 3)
        )

        # Continue with attention mechanism as before
        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.o_proj(attn_output)
        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
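As a quick sanity check for such a refactor, one could verify that the split projections reproduce the output of the current combined implementation after the pretrained weights are transferred. This is a sketch; `original_attention` and `split_attention` are hypothetical instances of the existing module and the proposed one, with the qkv weights split as sketched earlier and the output projection and relative position embeddings copied over unchanged:

import torch

# Hypothetical instances: `original_attention` is the current SamVisionAttention
# with a combined qkv layer, `split_attention` is the refactored module above.
hidden_states = torch.randn(1, 14, 14, original_attention.qkv.in_features)

with torch.no_grad():
    expected, _ = original_attention(hidden_states)
    actual, _ = split_attention(hidden_states)

# The two implementations should agree up to floating point tolerance.
print(torch.allclose(expected, actual, atol=1e-6))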