@@ -1383,6 +1383,124 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )

+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,  # stride (in elements) between consecutive input rows
+        output_block_stride,  # stride (in elements) between output row blocks
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        The transformation follows NVIDIA's block scaling factors layout:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate the starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        # Block rearrangement logic for the _to_blocked_single transformation:
+        # 1) Divide the 128 rows into four blocks of 32 (4x32 blocks)
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # 2) Rearrange to (32, 4, 4), then flatten to the final (32, 16) coordinates:
+        #    row = r_mod_32, col = r_div_32 * 4 + inner_col
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
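+        #    (worked example, for illustration: tile-local row 33, col 2 gives
+        #    r_div_32 = 1, r_mod_32 = 1, so dest = 1 * 16 + 1 * 4 + 2 = 22)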
+
+        # Flatten indices for storage
+        dest_indices_flat = tl.reshape(
+            dest_indices, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+
+        # Calculate the block offset using the provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
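+        # (worked example: for a 256x16 padded input, output_block_stride is
+        # 512 * 4 = 2048, so tile (pid_row=1, pid_col=2) starts at
+        # 2 * 512 + 1 * 2048 = 3072)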
+
+        # Store the rearranged values
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True),
+        )
+
+    def triton_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 scale tensor from row-major format to block-scaled
+        swizzle format.
+
+        This format is suitable for TMEM as described in the NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate the number of blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input row stride (the input is row-major and contiguous)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but for now keep
+        # it simple
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+
+        # Calculate grid dimensions
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        # Launch the kernel, passing the stride parameters
+        wrap_triton(scale_swizzle)[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
 else:

     def triton_to_mxfp8_dim1(
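
For context, a minimal eager-mode sketch of the same rearrangement can be useful for checking the kernel's output in tests. This is an assumption-based sketch, not part of the PR: the kernel comments refer to a `_to_blocked_single` transformation, and `swizzle_reference` below is a hypothetical equivalent that pads the scale tensor to multiples of (128, 4) and reorders each 128x4 tile into a contiguous 512-byte run laid out as (r_mod_32, r_div_32, inner_col), matching the index math in `scale_swizzle`.

import torch

def swizzle_reference(scale_tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical eager-mode equivalent of triton_block_rearrange (assumed,
    # not part of this PR), for cross-checking the kernel output.
    rows, cols = scale_tensor.shape
    n_row_blocks = (rows + 127) // 128
    n_col_blocks = (cols + 3) // 4
    padded = scale_tensor.new_zeros((n_row_blocks * 128, n_col_blocks * 4))
    padded[:rows, :cols] = scale_tensor
    # Split the row index into (row_block, r_div_32, r_mod_32) and the column
    # index into (col_block, inner_col), then reorder so each 128x4 tile is a
    # contiguous run of 512 bytes ordered (r_mod_32, r_div_32, inner_col).
    v = padded.reshape(n_row_blocks, 4, 32, n_col_blocks, 4)
    return v.permute(0, 3, 2, 1, 4).reshape(n_row_blocks * 128, n_col_blocks * 4)

# Example check on a shape that exercises the padding path:
scales = torch.randint(0, 256, (200, 10), dtype=torch.uint8, device="cuda")
assert torch.equal(triton_block_rearrange(scales), swizzle_reference(scales))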