@@ -104,13 +104,11 @@ def swizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         Target layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         This matches the baseline TMA block shape [block_n, packed_block_k] after swapping.
         """
-        if data.ndim == 2:
-            data = data.unsqueeze(0)
-        if data.ndim != 3:
-            raise ValueError(f"Expected 2D or 3D canonical data, got {data.ndim}D")
-
         data = self._canonical_to_physical(data)
-        E, K_packed, N = data.shape
+        leading_shape = data.shape[:-2]
+        E = math.prod(leading_shape)
+        K_packed, N = data.shape[-2:]
+        data = data.reshape(E, K_packed, N)
         tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n = \
             self._compute_params(E, K_packed, N)
 
@@ -139,6 +137,7 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         Input layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         """
         E = data.shape[0]
+        leading_shape = self.shape[:-2]
         # Recover original shape from self.shape (the logical shape passed to convert_layout)
         orig_K_packed = self.shape[-2] // 2 if self.is_fp4 else self.shape[-2]
         orig_N = self.shape[-1]
@@ -159,4 +158,6 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         # Trim padding back to original shape
         data = data[:, :orig_K_packed, :orig_N].contiguous()
         data = self._physical_to_canonical(data)
-        return data if len(self.shape) == 3 else data.squeeze(0)
+        if not leading_shape:
+            return data.squeeze(0)
+        return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])
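
The heart of the change is a flatten/restore pattern: collapse every leading (expert/batch) dimension into a single `E` axis before tiling, then rebuild the leading dimensions on the way out. Below is a minimal standalone sketch of that pattern; `flatten_leading` and `restore_leading` are hypothetical names for illustration, not functions from this codebase.

```python
import math
import torch

def flatten_leading(data: torch.Tensor) -> tuple[torch.Tensor, torch.Size]:
    # Collapse all leading dims into one axis E; math.prod(()) == 1,
    # so a plain 2D input gets a synthetic E = 1 axis.
    leading_shape = data.shape[:-2]
    E = math.prod(leading_shape)
    K, N = data.shape[-2:]
    return data.reshape(E, K, N), leading_shape

def restore_leading(data: torch.Tensor, leading_shape: torch.Size) -> torch.Tensor:
    # An empty leading_shape means the original input was 2D: drop the synthetic axis.
    if not leading_shape:
        return data.squeeze(0)
    return data.reshape(*leading_shape, *data.shape[-2:])

x = torch.randn(4, 2, 128, 64)           # two leading dims
flat, lead = flatten_leading(x)           # -> shape (8, 128, 64)
assert restore_leading(flat, lead).shape == x.shape
```

This is also why the old 2D/3D guard (`if data.ndim == 2 ... raise ValueError`) can be deleted: any input of rank >= 2 now normalizes to the same 3D working shape.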
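The `orig_K_packed = self.shape[-2] // 2 if self.is_fp4` branch relies on fp4 packing two 4-bit values into each byte, so the packed K extent is half the logical one. A hedged illustration of that nibble packing in isolation (the pair ordering here is an assumption for the demo, not necessarily the ordering this layout uses):

```python
import torch

logical_K, N = 8, 4
vals = torch.randint(0, 16, (logical_K, N), dtype=torch.uint8)  # one 4-bit value per byte
lo, hi = vals[0::2], vals[1::2]                                 # adjacent pairs along K
packed = lo | (hi << 4)                                         # shape (logical_K // 2, N)
assert packed.shape == (logical_K // 2, N)

# Unpacking the two nibbles restores the logical K extent.
unpacked = torch.empty(logical_K, N, dtype=torch.uint8)
unpacked[0::2] = packed & 0xF
unpacked[1::2] = packed >> 4
assert torch.equal(unpacked, vals)
```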