@@ -424,6 +424,9 @@ def get_extensions():
424
424
425
425
use_cutlass = False
426
426
cutlass_90a_sources = None
427
+ cutlass_100a_sources = None
428
+ build_for_sm90a = False
429
+ build_for_sm100a = False
427
430
if use_cuda and not IS_WINDOWS :
428
431
use_cutlass = True
429
432
cutlass_dir = os .path .join (third_party_path , "cutlass" )
@@ -453,32 +456,47 @@ def get_extensions():
453
456
)
454
457
455
458
cuda_arch_flags = _get_cuda_arch_flags ()
456
- build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
457
459
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
458
- if build_for_sm90 and not build_for_sm90a :
459
- cutlass_90a_sources = [
460
+ build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags
461
+ # Define sm90a sources
462
+ cutlass_90a_sources = [
463
+ os .path .join (
464
+ extensions_cuda_dir ,
465
+ "rowwise_scaled_linear_sparse_cutlass" ,
466
+ "rowwise_scaled_linear_sparse_cutlass_f8f8.cu" ,
467
+ ),
468
+ os .path .join (
469
+ extensions_cuda_dir ,
470
+ "to_sparse_semi_structured_cutlass_sm9x" ,
471
+ "to_sparse_semi_structured_cutlass_sm9x_f8.cu" ,
472
+ ),
473
+ os .path .join (extensions_cuda_dir , "activation24" , "sparsify24.cu" ),
474
+ os .path .join (extensions_cuda_dir , "activation24" , "sparse_gemm.cu" ),
475
+ ]
476
+ for dtypes in ["e4m3e4m3" , "e4m3e5m2" , "e5m2e4m3" , "e5m2e5m2" ]:
477
+ cutlass_90a_sources .append (
460
478
os .path .join (
461
479
extensions_cuda_dir ,
462
480
"rowwise_scaled_linear_sparse_cutlass" ,
463
- "rowwise_scaled_linear_sparse_cutlass_f8f8.cu" ,
464
- ),
465
- os .path .join (
466
- extensions_cuda_dir ,
467
- "to_sparse_semi_structured_cutlass_sm9x" ,
468
- "to_sparse_semi_structured_cutlass_sm9x_f8.cu" ,
469
- ),
470
- os .path .join (extensions_cuda_dir , "activation24" , "sparsify24.cu" ),
471
- os .path .join (extensions_cuda_dir , "activation24" , "sparse_gemm.cu" ),
472
- ]
473
- for dtypes in ["e4m3e4m3" , "e4m3e5m2" , "e5m2e4m3" , "e5m2e5m2" ]:
474
- cutlass_90a_sources .append (
475
- os .path .join (
476
- extensions_cuda_dir ,
477
- "rowwise_scaled_linear_sparse_cutlass" ,
478
- "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
479
- )
481
+ "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
480
482
)
481
- sources = [s for s in sources if s not in cutlass_90a_sources ]
483
+ )
484
+ # Always remove sm90a sources from main sources
485
+ sources = [s for s in sources if s not in cutlass_90a_sources ]
486
+
487
+ # Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
488
+ cutlass_100a_sources = [
489
+ os .path .join (
490
+ extensions_cuda_dir ,
491
+ "mx_kernels" ,
492
+ "mx_fp_cutlass_kernels.cu" ,
493
+ ),
494
+ ]
495
+ # Remove from main sources to prevent compilation with other architectures
496
+ sources = [
497
+ s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
498
+ ]
499
+
482
500
else :
483
501
# Remove CUTLASS-based kernels from the sources list. An
484
502
# assumption is that these files will have "cutlass" in its
@@ -492,6 +510,11 @@ def get_extensions():
492
510
493
511
ext_modules = []
494
512
if len (sources ) > 0 :
513
+ # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
514
+ sources = [
515
+ s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
516
+ ]
517
+
495
518
ext_modules .append (
496
519
extension (
497
520
"torchao._C" ,
@@ -502,21 +525,48 @@ def get_extensions():
502
525
)
503
526
)
504
527
505
- if cutlass_90a_sources is not None and len (cutlass_90a_sources ) > 0 :
528
+ # Only build the cutlass_90a extension if sm90a is in the architecture flags
529
+ if (
530
+ cutlass_90a_sources is not None
531
+ and len (cutlass_90a_sources ) > 0
532
+ and build_for_sm90a
533
+ ):
506
534
cutlass_90a_extra_compile_args = copy .deepcopy (extra_compile_args )
507
- cutlass_90a_extra_compile_args ["nvcc" ].extend (
508
- cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a" ]
535
+ # Only use sm90a architecture for these sources, ignoring other flags
536
+ cutlass_90a_extra_compile_args ["nvcc" ].append (
537
+ "-gencode=arch=compute_90a,code=sm_90a"
509
538
)
510
539
ext_modules .append (
511
540
extension (
512
- "torchao._C " ,
541
+ "torchao._C_cutlass_90a " ,
513
542
cutlass_90a_sources ,
514
543
py_limited_api = True ,
515
544
extra_compile_args = cutlass_90a_extra_compile_args ,
516
545
extra_link_args = extra_link_args ,
517
546
)
518
547
)
519
548
549
+ # Only build the cutlass_100a extension if sm100a is in the architecture flags
550
+ if (
551
+ cutlass_100a_sources is not None
552
+ and len (cutlass_100a_sources ) > 0
553
+ and build_for_sm100a
554
+ ):
555
+ cutlass_100a_extra_compile_args = copy .deepcopy (extra_compile_args )
556
+ # Only use sm100a architecture for these sources, ignoring cuda_arch_flags
557
+ cutlass_100a_extra_compile_args ["nvcc" ].append (
558
+ "-gencode=arch=compute_100a,code=sm_100a"
559
+ )
560
+ ext_modules .append (
561
+ extension (
562
+ "torchao._C_cutlass_100a" ,
563
+ cutlass_100a_sources ,
564
+ py_limited_api = True ,
565
+ extra_compile_args = cutlass_100a_extra_compile_args ,
566
+ extra_link_args = extra_link_args ,
567
+ )
568
+ )
569
+
520
570
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
521
571
if build_macos_arm_auto or os .getenv ("BUILD_TORCHAO_EXPERIMENTAL" ) == "1" :
522
572
build_options = BuildOptions ()
0 commit comments