@@ -1383,6 +1383,132 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1,
     )
 
+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,
+        output_block_stride,
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        Args:
+            scale_ptr: Pointer to the input scale tensor
+            scale_rows: Number of rows in the scale tensor
+            scale_cols: Number of columns in the scale tensor
+            output_ptr: Pointer to the output tensor
+            input_row_stride: Stride between rows in the input tensor
+            output_block_stride: Stride between blocks in the output tensor
+            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
+            BLOCK_COLS: Number of columns in a tile (compile-time constant)
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate the starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # Rearrange to (32, 4, 4), then to the final (32, 16) coordinates
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
+
+        # Flatten
+        dest_indices_flat = tl.reshape(
+            dest_indices, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+        flat_scales = tl.reshape(
+            input_scales, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+
+        # Calculate the block offset using the provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
+
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            flat_scales,
+        )
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 scale tensor from row-major format to block-scaled
+        swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate the number of (128, 4) blocks needed, padding up if necessary
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but keep it simple for now
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format: one row of blocks
+        # holds padded_cols // BLOCK_COLS tiles of BLOCK_ROWS * BLOCK_COLS bytes
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+
+        # Calculate grid dimensions
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        # Launch the kernel with the added stride parameters
+        # TODO fix before land
+        # wrap_triton(scale_swizzle)[grid](
+        scale_swizzle[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
 else:
 
     def triton_to_mxfp8_dim1(
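
For reference, the per-tile index math in scale_swizzle (dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols) is equivalent to a reshape/permute in plain PyTorch. A minimal sketch for checking a single (128, 4) tile; the helper name is illustrative and not part of the diff:

import torch

def swizzle_tile_reference(tile: torch.Tensor) -> torch.Tensor:
    # tile: one (128, 4) tile of e8m0 scales viewed as uint8, row-major
    assert tile.shape == (128, 4)
    # Split the rows into (r_div_32, r_mod_32) = (4, 32), keep the 4 columns,
    # then order the flat output as r_mod_32 * 16 + r_div_32 * 4 + c,
    # matching dest_indices in the kernel
    return tile.reshape(4, 32, 4).permute(1, 0, 2).reshape(-1)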
@@ -1394,3 +1520,6 @@ def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
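
A hypothetical usage sketch (shapes and values are illustrative, not from the PR); given torch 2.8+ with triton, a contiguous row-major e8m0 scale matrix viewed as uint8 can be swizzled directly:

import torch

# 256 x 8 scales: dimensions are already multiples of (128, 4), so no padding
scales = torch.randint(0, 256, (256, 8), dtype=torch.uint8, device="cuda")
swizzled = mx_block_rearrange(scales)
assert swizzled.shape == (256, 8)
# each (128, 4) input tile now occupies 512 contiguous bytes of the output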