Skip to content

Commit c4250a4

Browse files
drisspg and syed-ahmed
authored
Fixes MX formats build for blackwell (#2278)
* Fixes MX formats build for blackwell
* Adds missing line
* Adds missing line

---------

Co-authored-by: Syed Tousif Ahmed <[email protected]>
1 parent dd43f16 commit c4250a4

File tree

4 files changed

+83
-28
lines changed

4 files changed

+83
-28
lines changed

setup.py

Lines changed: 75 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,9 @@ def get_extensions():
424424

425425
use_cutlass = False
426426
cutlass_90a_sources = None
427+
cutlass_100a_sources = None
428+
build_for_sm90a = False
429+
build_for_sm100a = False
427430
if use_cuda and not IS_WINDOWS:
428431
use_cutlass = True
429432
cutlass_dir = os.path.join(third_party_path, "cutlass")
@@ -453,32 +456,47 @@ def get_extensions():
453456
)
454457

455458
cuda_arch_flags = _get_cuda_arch_flags()
456-
build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
457459
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
458-
if build_for_sm90 and not build_for_sm90a:
459-
cutlass_90a_sources = [
460+
build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags
461+
# Define sm90a sources
462+
cutlass_90a_sources = [
463+
os.path.join(
464+
extensions_cuda_dir,
465+
"rowwise_scaled_linear_sparse_cutlass",
466+
"rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
467+
),
468+
os.path.join(
469+
extensions_cuda_dir,
470+
"to_sparse_semi_structured_cutlass_sm9x",
471+
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
472+
),
473+
os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
474+
os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),
475+
]
476+
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
477+
cutlass_90a_sources.append(
460478
os.path.join(
461479
extensions_cuda_dir,
462480
"rowwise_scaled_linear_sparse_cutlass",
463-
"rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
464-
),
465-
os.path.join(
466-
extensions_cuda_dir,
467-
"to_sparse_semi_structured_cutlass_sm9x",
468-
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
469-
),
470-
os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
471-
os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),
472-
]
473-
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
474-
cutlass_90a_sources.append(
475-
os.path.join(
476-
extensions_cuda_dir,
477-
"rowwise_scaled_linear_sparse_cutlass",
478-
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
479-
)
481+
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
480482
)
481-
sources = [s for s in sources if s not in cutlass_90a_sources]
483+
)
484+
# Always remove sm90a sources from main sources
485+
sources = [s for s in sources if s not in cutlass_90a_sources]
486+
487+
# Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
488+
cutlass_100a_sources = [
489+
os.path.join(
490+
extensions_cuda_dir,
491+
"mx_kernels",
492+
"mx_fp_cutlass_kernels.cu",
493+
),
494+
]
495+
# Remove from main sources to prevent compilation with other architectures
496+
sources = [
497+
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
498+
]
499+
482500
else:
483501
# Remove CUTLASS-based kernels from the sources list. An
484502
# assumption is that these files will have "cutlass" in its
@@ -492,6 +510,11 @@ def get_extensions():
492510

493511
ext_modules = []
494512
if len(sources) > 0:
513+
# Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
514+
sources = [
515+
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
516+
]
517+
495518
ext_modules.append(
496519
extension(
497520
"torchao._C",
@@ -502,21 +525,48 @@ def get_extensions():
502525
)
503526
)
504527

505-
if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0:
528+
# Only build the cutlass_90a extension if sm90a is in the architecture flags
529+
if (
530+
cutlass_90a_sources is not None
531+
and len(cutlass_90a_sources) > 0
532+
and build_for_sm90a
533+
):
506534
cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
507-
cutlass_90a_extra_compile_args["nvcc"].extend(
508-
cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"]
535+
# Only use sm90a architecture for these sources, ignoring other flags
536+
cutlass_90a_extra_compile_args["nvcc"].append(
537+
"-gencode=arch=compute_90a,code=sm_90a"
509538
)
510539
ext_modules.append(
511540
extension(
512-
"torchao._C",
541+
"torchao._C_cutlass_90a",
513542
cutlass_90a_sources,
514543
py_limited_api=True,
515544
extra_compile_args=cutlass_90a_extra_compile_args,
516545
extra_link_args=extra_link_args,
517546
)
518547
)
519548

549+
# Only build the cutlass_100a extension if sm100a is in the architecture flags
550+
if (
551+
cutlass_100a_sources is not None
552+
and len(cutlass_100a_sources) > 0
553+
and build_for_sm100a
554+
):
555+
cutlass_100a_extra_compile_args = copy.deepcopy(extra_compile_args)
556+
# Only use sm100a architecture for these sources, ignoring cuda_arch_flags
557+
cutlass_100a_extra_compile_args["nvcc"].append(
558+
"-gencode=arch=compute_100a,code=sm_100a"
559+
)
560+
ext_modules.append(
561+
extension(
562+
"torchao._C_cutlass_100a",
563+
cutlass_100a_sources,
564+
py_limited_api=True,
565+
extra_compile_args=cutlass_100a_extra_compile_args,
566+
extra_link_args=extra_link_args,
567+
)
568+
)
569+
520570
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
521571
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
522572
build_options = BuildOptions()

third_party/cutlass

Submodule cutlass updated 530 files

torchao/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
so_files = list(Path(__file__).parent.glob("_C*.so"))
2727
if len(so_files) > 0:
28-
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
29-
torch.ops.load_library(str(so_files[0]))
28+
for file in so_files:
29+
torch.ops.load_library(str(file))
3030
from . import ops
3131

3232
# The following library contains CPU kernels from torchao/experimental

torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
//
44
// This source code is licensed under the BSD 3-Clause license found in the
55
// LICENSE file in the root directory of this source tree.
6+
7+
// Ensure this file is only compiled with sm100a architecture
8+
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
9+
#error "This file must be compiled with compute capability 10.0a or higher (Blackwell architecture)"
10+
#endif
611
#include <torch/library.h>
712

813
#include <ATen/ATen.h>

0 commit comments

Comments (0)