@@ -790,6 +790,84 @@ def get_mixed_freqs(
790790 return rope_embeds .to (dtype )
791791
792792
class RotaryEmbeddingMRope(nn.Module):
    """Interleaved multimodal RoPE (Qwen2-VL style) restricted to 2D image patch grids.

    Sibling of ``RotaryEmbeddingCat`` with the same calling convention: ``get_embed(shape)`` returns a
    single ``[N, 2*dim]`` tensor laid out as ``[sin | cos]`` for ``apply_rot_embed_cat(..., half=True)``.

    Channel layout:
        The ``dim // 2`` frequency channels are shared between three axes in a strided
        ``T,H,W,T,H,W,...`` interleave. Channels ``1, 4, 7, ...`` below ``3 * mrope_section[1]`` follow the
        row (height) index, channels ``2, 5, 8, ...`` below ``3 * mrope_section[2]`` follow the column
        (width) index, and every other channel is temporal. Both strided ranges are clamped to
        ``dim // 2``, so the per-axis channel counts only equal ``mrope_section`` when the sections tile
        exactly (e.g. ``(12, 12, 12)`` with 36 channels). With ``dim=64`` and ``(8, 12, 12)`` the split is
        11 temporal / 11 height / 10 width channels.

    Temporal channels always see position 0, i.e. a zero rotation angle (sin=0, cos=1), so for images the
    embedding degenerates to a 2-axis ``(h, w)`` rope.

    Only ``grid_indexing='ij'`` (row-major ``(y, x)`` patch order) is accepted.

    Args:
        dim: Per-head dimension; must be even.
        mrope_section: ``(temporal, height, width)`` section sizes; must sum to ``dim // 2``.
        temperature: Base of the geometric frequency progression.
        grid_indexing: Meshgrid indexing mode; must be ``'ij'``.
        device: Device on which the buffers are created.
        dtype: Accepted for signature parity with the other rope modules; this class does not read it.
    """

    def __init__(
            self,
            dim: int,
            mrope_section: Tuple[int, int, int] = (8, 12, 12),
            temperature: float = 10000.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        super().__init__()
        assert dim % 2 == 0, 'dim (head_dim) must be even'
        assert sum(mrope_section) == dim // 2, \
            f"sum(mrope_section)={sum(mrope_section)} must equal head_dim//2={dim // 2}"
        assert grid_indexing == 'ij', \
            "RotaryEmbeddingMRope supports grid_indexing='ij' only (GenLIP/NaFlex (y,x) patch order)."

        self.dim = dim
        self.mrope_section = mrope_section
        self.temperature = temperature
        self.grid_indexing = grid_indexing

        num_channels = dim // 2

        # One theta-style inverse frequency per channel; the same vector serves every axis.
        exponents = torch.arange(0, dim, 2, device=device).float() / dim
        self.register_buffer('inv_freq', 1.0 / (temperature ** exponents), persistent=False)  # [dim//2]

        # Per-channel axis id: 0 = temporal (default), 1 = height, 2 = width.
        # The temporal size is implied by the sum check above, so only the H/W sizes are consumed.
        # Slicing clamps the stop to the tensor length, which is what produces the clamped interleave.
        _, height_section, width_section = mrope_section
        channel_axis = torch.zeros(num_channels, dtype=torch.long, device=device)
        for axis_id, section in ((1, height_section), (2, width_section)):
            # The axis id doubles as the start offset of its stride-3 channel run.
            channel_axis[axis_id:section * 3:3] = axis_id
        self.register_buffer('axis', channel_axis, persistent=False)

    def get_embed(self, shape: List[int]) -> torch.Tensor:
        """Build the concatenated sin/cos rope table for one patch grid.

        Args:
            shape: ``(H, W)`` patch grid.

        Returns:
            Tensor of shape ``[H*W, 2*dim]`` laid out as ``[sin, sin, cos, cos]`` blocks of ``dim // 2``
            channels each, in row-major patch order.
        """
        grid_h, grid_w = shape
        device = self.inv_freq.device
        row_grid, col_grid = torch.meshgrid(
            torch.arange(grid_h, device=device),
            torch.arange(grid_w, device=device),
            indexing=self.grid_indexing,
        )
        rows = row_grid.reshape(-1).float().unsqueeze(-1)  # [N, 1]
        cols = col_grid.reshape(-1).float().unsqueeze(-1)  # [N, 1]

        # Start from all-zero positions so temporal channels keep angle 0 (identity rotation),
        # then route the row / column index into the height / width channels.
        positions = torch.zeros(rows.shape[0], self.dim // 2, device=device)  # [N, dim//2]
        positions = torch.where(self.axis == 1, rows, positions)
        positions = torch.where(self.axis == 2, cols, positions)

        angles = positions * self.inv_freq  # [N, dim//2]
        sin, cos = angles.sin(), angles.cos()
        # Each half is duplicated to span the full head dim expected by the half-rotation layout.
        return torch.cat([sin, sin, cos, cos], dim=-1)  # [N, 2*dim]
869+
870+
793871class RotaryEmbeddingMixed (nn .Module ):
794872 """Rotary position embedding with depth-dependent learnable frequencies.
795873
@@ -1246,6 +1324,7 @@ def create_rope_embed(
12461324 - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
12471325 - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
12481326 - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
1327+ - 'mrope': RotaryEmbeddingMRope (interleaved multimodal RoPE; requires `mrope_section`)
12491328 dim: Total embedding dimension
12501329 num_heads: Number of attention heads
12511330 **kwargs: Additional arguments passed to the specific RoPE class
@@ -1268,5 +1347,9 @@ def create_rope_embed(
12681347 kwargs .pop ('in_pixels' , None ) # doesn't support
12691348 kwargs .pop ('ref_feat_shape' , None ) # doesn't support
12701349 return RotaryEmbeddingDinoV3 (dim = dim // num_heads , ** kwargs )
1350+ elif rope_type == 'mrope' :
1351+ for k in ('in_pixels' , 'ref_feat_shape' , 'rotate_half' ):
1352+ kwargs .pop (k , None ) # mrope builds the half-layout cat tensor itself; these don't apply
1353+ return RotaryEmbeddingMRope (dim = dim // num_heads , ** kwargs )
12711354 else :
12721355 raise ValueError (f"Unknown RoPE type: { rope_type } " )
0 commit comments