@@ -1383,6 +1383,124 @@ def triton_to_mxfp8_dim1_reference(
        scale_e8m0_dim1,
    )

+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,
+        output_block_stride,
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        Args:
+            scale_ptr: Pointer to the input scale tensor
+            scale_rows: Number of rows in the scale tensor
+            scale_cols: Number of columns in the scale tensor
+            output_ptr: Pointer to the output tensor
+            input_row_stride: Stride between rows in the input tensor
+            output_block_stride: Stride between blocks in the output tensor
+            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
+            BLOCK_COLS: Number of columns in a tile (compile-time constant)
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
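+        # Out-of-bounds (padding) positions load as 0, so padded regions of the
+        # output receive zero scales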
1429
+
1430
+ r_div_32 = rows // 32
1431
+ r_mod_32 = rows % 32
1432
+
1433
+ # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
1434
+ dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
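+        # Net effect: each (128, 4) input tile becomes one contiguous (32, 16)
+        # output block, with element (r, c) landing at row r % 32,
+        # column (r // 32) * 4 + c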
1435
+
1436
+ # Flatten
1437
+ dest_indices_flat = tl .reshape (dest_indices , (BLOCK_ROWS * BLOCK_COLS ))
1438
+ scales_flat = tl .reshape (input_scales , (BLOCK_ROWS * BLOCK_COLS ))
1439
+
1440
+ # Calculate block offset using provided output block stride
1441
+ LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
1442
+ block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride )
1443
+
1444
+ tl .store (
1445
+ output_ptr + block_offset + dest_indices_flat ,
1446
+ scales_flat ,
1447
+ )
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
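+        # The swizzled layout operates on (128, 4) tiles of scales, so the
+        # output is padded up to whole tiles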
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but for now keep it simple
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
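+        # Equals 128 * padded_cols: the number of scale elements in one full row
+        # of (128, 4) blocks, i.e. the offset between consecutive row blocks in
+        # the flattened output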
+
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        wrap_triton(scale_swizzle)[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
else:

    def triton_to_mxfp8_dim1(
@@ -1394,3 +1512,6 @@ def triton_to_mxfp8_dim1_reference(
        x_hp: torch.Tensor, block_size
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise AssertionError("needs torch version 2.8+ and triton")
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")