Skip to content

Commit 763f54c

Browse files
committed
Fix example runtimes
1 parent acc7b79 commit 763f54c

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

examples/65_distributed_gemm/65_distributed_gemm.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ using namespace cute;
132132
using TP = _8;
133133
static constexpr int TP_ = TP{};
134134

135-
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
135+
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
136136
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
137137

138138
// Distributed GEMM tiling/sharding schedule
@@ -252,7 +252,7 @@ HostTensorB tensor_B_arr[TP_];
252252
HostTensorD tensor_C_arr[TP_];
253253
HostTensorD tensor_D_arr[TP_];
254254

255-
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
255+
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) &&
256256
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
257257

258258
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -345,8 +345,7 @@ struct Result {
345345

346346
};
347347

348-
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
349-
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
348+
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
350349

351350
/////////////////////////////////////////////////////////////////////////////////////////////////
352351
/// GEMM setup and evaluation
@@ -804,8 +803,7 @@ int run(Options &options) {
804803
return 0;
805804
}
806805

807-
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
808-
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
806+
#endif // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
809807

810808
///////////////////////////////////////////////////////////////////////////////////////////////////
811809

@@ -859,7 +857,7 @@ int main(int argc, char const **args) {
859857
// Evaluate CUTLASS kernels
860858
//
861859

862-
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
860+
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
863861
run(options);
864862
#else
865863
std::cerr

examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ using namespace cute;
132132
using TP = _8;
133133
static constexpr int TP_ = TP{};
134134

135-
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
135+
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
136136
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
137137

138138
// Distributed GEMM tiling/sharding schedule
@@ -254,7 +254,7 @@ HostTensorB tensor_B_arr[TP_];
254254
HostTensorD tensor_C_arr[TP_];
255255
HostTensorD tensor_D_arr[TP_];
256256

257-
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
257+
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) &&
258258
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
259259

260260
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -347,8 +347,7 @@ struct Result {
347347

348348
};
349349

350-
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
351-
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
350+
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
352351

353352
/////////////////////////////////////////////////////////////////////////////////////////////////
354353
/// GEMM setup and evaluation
@@ -812,8 +811,7 @@ int run(Options &options) {
812811
return 0;
813812
}
814813

815-
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
816-
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
814+
#endif // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
817815

818816
///////////////////////////////////////////////////////////////////////////////////////////////////
819817

@@ -867,7 +865,7 @@ int main(int argc, char const **args) {
867865
// Evaluate CUTLASS kernels
868866
//
869867

870-
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
868+
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
871869
run(options);
872870
#else
873871
std::cerr

0 commit comments

Comments
 (0)