Skip to content

Commit 0647e28

Browse files
committed
Preliminary naflexvit support for genlip image encoder (gated attn + mrope) that could allow remapping of weights
1 parent fbe27d6 commit 0647e28

5 files changed

Lines changed: 125 additions & 5 deletions

File tree

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
RotaryEmbedding,
136136
RotaryEmbeddingCat,
137137
RotaryEmbeddingMixed,
138+
RotaryEmbeddingMRope,
138139
RotaryEmbeddingDinoV3,
139140
get_mixed_freqs,
140141
create_rope_embed,

timm/layers/attention.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
qk_norm: bool = False,
6161
scale_norm: bool = False,
6262
proj_bias: bool = True,
63+
gated: bool = False,
6364
attn_drop: float = 0.,
6465
proj_drop: float = 0.,
6566
norm_layer: Optional[Type[nn.Module]] = None,
@@ -77,6 +78,7 @@ def __init__(
7778
qk_norm: Whether to apply normalization to query and key vectors.
7879
scale_norm: Whether to apply normalization to attention output before projection.
7980
proj_bias: Whether to use bias in the output projection.
81+
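gated: Apply an element-wise sigmoid gate (one value per attention output channel, computed from the layer input) to the attention output before the output projection (anti attention-sink, GenLIP-style).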
8082
attn_drop: Dropout rate applied to the attention weights.
8183
proj_drop: Dropout rate applied after the output projection.
8284
norm_layer: Normalization layer constructor for QK normalization if enabled.
@@ -102,6 +104,7 @@ def __init__(
102104
self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
103105
self.attn_drop = nn.Dropout(attn_drop)
104106
self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
107+
self.gate = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) if gated else None
105108
self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
106109
self.proj_drop = nn.Dropout(proj_drop)
107110

@@ -112,6 +115,7 @@ def forward(
112115
is_causal: bool = False,
113116
) -> torch.Tensor:
114117
B, N, C = x.shape
118+
gate = self.gate(x).sigmoid() if self.gate is not None else None
115119
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
116120
q, k, v = qkv.unbind(0)
117121
q, k = self.q_norm(q), self.k_norm(k)
@@ -134,6 +138,8 @@ def forward(
134138

135139
x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
136140
x = self.norm(x)
141+
if gate is not None:
142+
x = x * gate
137143
x = self.proj(x)
138144
x = self.proj_drop(x)
139145
return x
@@ -165,6 +171,7 @@ def __init__(
165171
scale_norm: bool = False,
166172
proj_bias: bool = True,
167173
rotate_half: bool = False,
174+
gated: bool = False,
168175
device=None,
169176
dtype=None,
170177
):
@@ -218,6 +225,7 @@ def __init__(
218225
self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
219226
self.attn_drop = nn.Dropout(attn_drop)
220227
self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
228+
self.gate = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd) if gated else None
221229
self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
222230
self.proj_drop = nn.Dropout(proj_drop)
223231

@@ -240,6 +248,7 @@ def forward(
240248
Tensor of shape (batch_size, sequence_length, dim_out)
241249
"""
242250
B, N, C = x.shape
251+
gate = self.gate(x).sigmoid() if self.gate is not None else None
243252

244253
if self.qkv is not None:
245254
qkv = self.qkv(x)
@@ -277,6 +286,8 @@ def forward(
277286

278287
x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
279288
x = self.norm(x)
289+
if gate is not None:
290+
x = x * gate
280291
x = self.proj(x)
281292
x = self.proj_drop(x)
282293
return x

timm/layers/pos_embed_sincos.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,84 @@ def get_mixed_freqs(
790790
return rope_embeds.to(dtype)
791791

792792

793+
class RotaryEmbeddingMRope(nn.Module):
    """Interleaved multimodal rotary embedding (MRoPE) restricted to a 2D patch grid.

    The ``dim // 2`` frequency channels are tagged with an axis id in a strided
    ``T, H, W, T, H, W, ...`` pattern: channels ``1, 4, 7, ...`` rotate with the patch row
    (height), channels ``2, 5, 8, ...`` rotate with the patch column (width), and every
    other channel is "temporal". Temporal channels always see position 0, so their angle is
    0 (``sin = 0``, ``cos = 1``) and they leave the features unrotated.

    ``mrope_section = (T, H, W)`` bounds how far each stride runs: height channels are
    ``range(1, 3 * H, 3)`` and width channels are ``range(2, 3 * W, 3)``, both clipped to
    ``dim // 2`` by ordinary slice semantics. The per-axis channel *counts* therefore equal
    ``mrope_section`` only when the strides tile exactly (e.g. ``dim=72, (12, 12, 12)`` ->
    12/12/12); otherwise they are the clipped counts (e.g. ``dim=64, (8, 12, 12)`` ->
    11 temporal / 11 height / 10 width). The temporal entry is only used in the sum check.

    ``get_embed`` returns a single ``[N, 2 * dim]`` tensor laid out as
    ``[sin, sin, cos, cos]`` (each piece ``dim // 2`` wide), i.e. the concatenated
    "half" layout. NOTE(review): the original author describes this as a drop-in sibling
    of ``RotaryEmbeddingCat`` consumed by ``apply_rot_embed_cat(..., half=True)`` and as
    matching the reference GenLIP / Qwen2-VL interleave -- neither is visible here, confirm
    against those implementations.

    Only ``grid_indexing='ij'`` (row-major ``(y, x)`` patch order) is accepted.

    Args:
        dim: Per-head embedding dimension; must be even.
        mrope_section: ``(T, H, W)`` strided extents; must sum to ``dim // 2``.
        temperature: Base of the geometric frequency progression.
        grid_indexing: Meshgrid indexing mode; must be ``'ij'``.
        device: Device on which the buffers are created.
        dtype: Accepted for signature compatibility; not used (frequencies are float32).
    """

    def __init__(
            self,
            dim: int,
            mrope_section: Tuple[int, int, int] = (8, 12, 12),
            temperature: float = 10000.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        super().__init__()
        assert dim % 2 == 0, 'dim (head_dim) must be even'
        assert sum(mrope_section) == dim // 2, \
            f"sum(mrope_section)={sum(mrope_section)} must equal head_dim//2={dim // 2}"
        assert grid_indexing == 'ij', \
            "RotaryEmbeddingMRope supports grid_indexing='ij' only (GenLIP/NaFlex (y,x) patch order)."

        self.dim = dim
        self.mrope_section = mrope_section
        self.temperature = temperature
        self.grid_indexing = grid_indexing

        num_freqs = dim // 2

        # One inverse frequency per channel pair; the same vector is shared by every axis.
        exponents = torch.arange(0, dim, 2, device=device).float() / dim
        self.register_buffer('inv_freq', 1.0 / (temperature ** exponents), persistent=False)  # [dim//2]

        # Axis id per frequency channel: 0 = temporal (never rotated), 1 = height, 2 = width.
        # Slice assignment silently clips a stop beyond dim//2, which gives the clamped
        # interleave described in the class docstring.
        _, height_extent, width_extent = mrope_section
        channel_axis = torch.zeros(num_freqs, dtype=torch.long, device=device)
        channel_axis[1:height_extent * 3:3] = 1
        channel_axis[2:width_extent * 3:3] = 2
        self.register_buffer('axis', channel_axis, persistent=False)

    def get_embed(self, shape: List[int]) -> torch.Tensor:
        """Build the concatenated sin/cos rope tensor for an ``(H, W)`` patch grid.

        Args:
            shape: ``(H, W)`` patch grid size.

        Returns:
            Tensor of shape ``[H*W, 2*dim]`` laid out as ``[sin, sin, cos, cos]``, with
            tokens in row-major ``(y, x)`` order.
        """
        grid_h, grid_w = shape
        device = self.inv_freq.device

        row_grid, col_grid = torch.meshgrid(
            torch.arange(grid_h, device=device),
            torch.arange(grid_w, device=device),
            indexing=self.grid_indexing,
        )
        rows = row_grid.reshape(-1).float()  # [N] patch row of each token
        cols = col_grid.reshape(-1).float()  # [N] patch column of each token

        # Per-token, per-channel position: row index on height channels, column index on
        # width channels, and 0 on temporal channels (angle 0 -> identity rotation).
        positions = torch.zeros(rows.shape[0], self.dim // 2, device=device)  # [N, dim//2]
        positions[:, self.axis == 1] = rows[:, None]
        positions[:, self.axis == 2] = cols[:, None]

        angles = positions * self.inv_freq  # [N, dim//2]
        sin, cos = angles.sin(), angles.cos()
        # Same values as cat([cat([a, a]).sin(), cat([a, a]).cos()]) for a = angles.
        return torch.cat([sin, sin, cos, cos], dim=-1)  # [N, 2*dim]
869+
870+
793871
class RotaryEmbeddingMixed(nn.Module):
794872
"""Rotary position embedding with depth-dependent learnable frequencies.
795873
@@ -1246,6 +1324,7 @@ def create_rope_embed(
12461324
- 'cat': RotaryEmbeddingCat (concatenated sin/cos)
12471325
- 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
12481326
- 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
1327+
- 'mrope': RotaryEmbeddingMRope (interleaved multimodal RoPE; requires `mrope_section`)
12491328
dim: Total embedding dimension
12501329
num_heads: Number of attention heads
12511330
**kwargs: Additional arguments passed to the specific RoPE class
@@ -1268,5 +1347,9 @@ def create_rope_embed(
12681347
kwargs.pop('in_pixels', None) # doesn't support
12691348
kwargs.pop('ref_feat_shape', None) # doesn't support
12701349
return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs)
1350+
elif rope_type == 'mrope':
1351+
for k in ('in_pixels', 'ref_feat_shape', 'rotate_half'):
1352+
kwargs.pop(k, None) # mrope builds the half-layout cat tensor itself; these don't apply
1353+
return RotaryEmbeddingMRope(dim=dim // num_heads, **kwargs)
12711354
else:
12721355
raise ValueError(f"Unknown RoPE type: {rope_type}")

timm/models/eva.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
qk_norm: bool = False,
123123
scale_norm: bool = True,
124124
rotate_half: bool = False,
125+
gated: bool = False,
125126
device=None,
126127
dtype=None,
127128
):
@@ -176,6 +177,7 @@ def __init__(
176177
self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
177178
self.attn_drop = nn.Dropout(attn_drop)
178179
self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
180+
self.gate = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) if gated else None
179181
self.proj = nn.Linear(attn_dim, dim, **dd)
180182
self.proj_drop = nn.Dropout(proj_drop)
181183

@@ -213,6 +215,7 @@ def forward(
213215
Tensor of shape (batch_size, sequence_length, embedding_dim)
214216
"""
215217
B, N, C = x.shape
218+
gate = self.gate(x).sigmoid() if self.gate is not None else None
216219

217220
if self.qkv is not None:
218221
if self.q_bias is None:
@@ -257,6 +260,8 @@ def forward(
257260

258261
x = x.transpose(1, 2).reshape(B, N, C)
259262
x = self.norm(x)
263+
if gate is not None:
264+
x = x * gate
260265
x = self.proj(x)
261266
x = self.proj_drop(x)
262267
return x
@@ -282,6 +287,7 @@ def __init__(
282287
num_prefix_tokens: int = 1,
283288
attn_type: str = 'eva',
284289
rotate_half: bool = False,
290+
gated_attn: bool = False,
285291
proj_drop: float = 0.,
286292
attn_drop: float = 0.,
287293
drop_path: float = 0.,
@@ -331,6 +337,7 @@ def __init__(
331337
norm_layer=norm_layer,
332338
scale_norm=scale_attn_inner,
333339
rotate_half=rotate_half,
340+
gated=gated_attn,
334341
**dd,
335342
)
336343
self.init_values = init_values
@@ -409,6 +416,7 @@ def __init__(
409416
mlp_ratio: float = 4.,
410417
attn_type: str = 'eva',
411418
rotate_half: bool = False,
419+
gated_attn: bool = False,
412420
swiglu_mlp: bool = False,
413421
swiglu_align_to: int = 0,
414422
scale_mlp: bool = False,
@@ -462,6 +470,7 @@ def __init__(
462470
norm_layer=norm_layer,
463471
scale_norm=scale_attn_inner,
464472
rotate_half=rotate_half,
473+
gated=gated_attn,
465474
**dd,
466475
)
467476
self.norm1 = norm_layer(dim, **dd)

timm/models/naflexvit.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ class NaFlexVitCfg:
9898
pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation
9999

100100
# ROPE specific configuration
101-
rope_type: str = '' # ROPE type: '' or 'none' for no ROPE, 'axial' for standard, 'mixed' for learnable frequencies
101+
rope_type: str = '' # ROPE type: '' / 'none', 'axial', 'mixed', 'dinov3', or 'mrope' (interleaved multimodal)
102102
rope_temperature: float = 10000.0 # Temperature for ROPE frequency computation
103103
rope_ref_feat_shape: Optional[Tuple[int, int]] = None
104104
rope_grid_offset: float = 0. # Grid offset for non-pixel ROPE mode
105105
rope_grid_indexing: str = 'ij' # Grid indexing mode for ROPE ('ij' or 'xy')
106-
rope_rotate_half: bool = False # Use rotate_half layout for ROPE (DINOv3 uses True)
106+
rope_rotate_half: bool = False # Use rotate_half layout for ROPE (DINOv3 and 'mrope' use True)
107+
rope_mrope_section: Optional[Tuple[int, int, int]] = None # (T,H,W) channel split for rope_type='mrope'
107108

108109
# Image processing
109110
dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution
@@ -137,6 +138,7 @@ class NaFlexVitCfg:
137138

138139
# EVA-specific parameters
139140
attn_type: str = 'standard' # Attention type: 'standard', 'eva', 'rope'
141+
attn_gated: bool = False # Apply sigmoid output gate in attention (anti attention-sink, GenLIP-style)
140142
swiglu_mlp: bool = False # Use SwiGLU MLP variant
141143
qkv_fused: bool = True # Whether to use fused QKV projections
142144

@@ -282,7 +284,8 @@ def get_block_fn(cfg: NaFlexVitCfg) -> Callable:
282284
use_eva_features = (
283285
cfg.attn_type in ('eva', 'rope') or
284286
cfg.rope_type not in ('', 'none') or # Any ROPE type requires EVA blocks
285-
cfg.swiglu_mlp
287+
cfg.swiglu_mlp or
288+
cfg.attn_gated # gated attention is implemented on the EVA/rope attention path
286289
)
287290

288291
if use_eva_features:
@@ -300,7 +303,8 @@ def get_block_fn(cfg: NaFlexVitCfg) -> Callable:
300303
scale_attn_inner=cfg.scale_attn_inner_norm,
301304
qkv_fused=cfg.qkv_fused,
302305
num_prefix_tokens=num_prefix_tokens,
303-
rotate_half=cfg.rope_rotate_half,
306+
rotate_half=cfg.rope_rotate_half or cfg.rope_type == 'mrope', # MRoPE requires the half-rotation layout
307+
gated_attn=cfg.attn_gated,
304308
)
305309
else:
306310
# Standard ViT block
@@ -1194,7 +1198,9 @@ def __init__(
11941198
self.rope: Optional[nn.Module] = None
11951199
self.rope_is_mixed = False
11961200
if cfg.rope_type and cfg.rope_type != 'none':
1197-
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat, RotaryEmbeddingDinoV3, RotaryEmbeddingMixed
1201+
from timm.layers.pos_embed_sincos import (
1202+
RotaryEmbeddingCat, RotaryEmbeddingDinoV3, RotaryEmbeddingMixed, RotaryEmbeddingMRope,
1203+
)
11981204
if cfg.rope_type == 'mixed':
11991205
self.rope = RotaryEmbeddingMixed(
12001206
cfg.embed_dim,
@@ -1228,6 +1234,16 @@ def __init__(
12281234
**dd,
12291235
)
12301236
self.rope_is_mixed = False
1237+
elif cfg.rope_type == 'mrope':
1238+
assert cfg.rope_mrope_section is not None, "rope_type='mrope' requires cfg.rope_mrope_section"
1239+
self.rope = RotaryEmbeddingMRope(
1240+
cfg.embed_dim // cfg.num_heads,
1241+
mrope_section=cfg.rope_mrope_section,
1242+
temperature=cfg.rope_temperature,
1243+
grid_indexing=cfg.rope_grid_indexing,
1244+
**dd,
1245+
)
1246+
self.rope_is_mixed = False
12311247
else:
12321248
raise ValueError(f"Unknown rope_type: {cfg.rope_type}")
12331249

0 commit comments

Comments
 (0)