@@ -1383,6 +1383,133 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )
 
+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,
+        output_block_stride,
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        Args:
+            scale_ptr: Pointer to the input scale tensor
+            scale_rows: Number of rows in the scale tensor
+            scale_cols: Number of columns in the scale tensor
+            output_ptr: Pointer to the output tensor
+            input_row_stride: Stride between rows in the input tensor
+            output_block_stride: Stride between blocks in the output tensor
+            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
+            BLOCK_COLS: Number of columns in a tile (compile-time constant)
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        # Block rearrangement logic for the _to_blocked_single transformation:
+        # 1) Divide into 4×32 blocks
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # 2) Rearrange to (32, 4, 4), then to final (32, 16) coordinates:
+        #    row = r_mod_32, col = (r_div_32 * 4 + inner_col)
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
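+        # For example, an element at tile-local (row=37, col=2) has r_mod_32 = 5
+        # and r_div_32 = 1, so it lands at flat offset 5 * 16 + 1 * 4 + 2 = 86
+        # within this tile's 32x16 (512-element) output block.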
+
+        # Flatten indices for storage
+        dest_indices_flat = tl.reshape(
+            dest_indices, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+
+        # Calculate block offset using the provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
+
+        # Store the rearranged values
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True),
+        )
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 scale tensor from row-major format to block-scaled swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate the number of blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but for now keep it simple
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
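+        # e.g. with padded_cols = 8 there are two column blocks per row block,
+        # so consecutive row blocks start 128 * 4 * 2 = 1024 elements apart.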
+
+        # Calculate grid dimensions
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        # Launch kernel with added stride parameters
+        # TODO fix before land
+        # wrap_triton(scale_swizzle)[grid](
+        scale_swizzle[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1394,3 +1521,6 @@ def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
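A minimal usage sketch of the new `mx_block_rearrange` helper, assuming it is importable from the module this diff modifies, that a CUDA device with Triton is available, and that an illustrative (256, 8) uint8 scale tensor stands in for real E8M0 scales:

```python
import torch

# Illustrative E8M0 scale tensor: 8-bit values, shape (256, 8), on a CUDA
# device so the Triton kernel can run (shape and device are assumptions).
scale = torch.randint(0, 256, (256, 8), dtype=torch.uint8, device="cuda")

# 256 rows -> 2 row blocks of 128; 8 cols -> 2 col blocks of 4. The output
# keeps the (256, 8) footprint, but each 128x4 tile is stored as one
# contiguous 512-element (32x16) block in the swizzled layout.
swizzled = mx_block_rearrange(scale)
assert swizzled.shape == (256, 8)
```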