@@ -1383,6 +1383,128 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1,
     )
 
+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,  # row stride of the row-major input scale tensor
+        output_block_stride,  # stride between row-bands of blocks in the output
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        The transformation follows NVIDIA's block scaling factors layout:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate the starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        # Block rearrangement logic for the _to_blocked_single transformation:
+        # 1) Divide the tile's rows into 4×32 blocks
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # 2) Rearrange to (32, 4, 4), then to the final (32, 16) coordinates:
+        #    row = r_mod_32, col = (r_div_32 * 4 + inner_col)
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
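+        # For example (illustrative): the tile element at (row=37, col=2) has
+        # r_div_32 = 1 and r_mod_32 = 5, so it maps to flat offset
+        # 5 * 16 + 1 * 4 + 2 = 86, i.e. (row 5, col 6) of the 32x16 blocked view.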
+
+        # Flatten indices for storage
+        dest_indices_flat = tl.reshape(
+            dest_indices, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+
+        # Calculate the block offset using the provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
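+        # E.g. (illustrative): with 128x4 tiles (LOCAL_NUMEL = 512) and two
+        # column blocks (output_block_stride = 1024), the tile at
+        # (pid_row=1, pid_col=1) starts at 1 * 512 + 1 * 1024 = 1536.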
+
+        # Store the rearranged values
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True),
+        )
+
+    def triton_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 tensor scale from row-major format to block-scaled
+        swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        # Validate input
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        # Get dimensions
+        rows, cols = scale_tensor.shape
+
+        # Calculate the number of 128x4 blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+
+        # Calculate padded dimensions
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        # Create the output tensor
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input row stride (the input is contiguous and row-major)
+        input_row_stride = cols
+
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
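+        # E.g. (illustrative): for padded_cols = 8 this is 128 * 4 * 2 = 1024,
+        # i.e. how far the output advances per 128-row band of tiles.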
+
+        # Calculate grid dimensions
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        # Launch the kernel, passing the input and output strides
+        wrap_triton(scale_swizzle)[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
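+    # Illustrative usage (assumed shapes and device, not a call site in this diff):
+    #   scales = torch.randint(0, 256, (256, 16), dtype=torch.uint8, device="cuda")
+    #   swizzled = triton_block_rearrange(scales)  # (256, 16), blocked layout
+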
 else:
 
     def triton_to_mxfp8_dim1(