File tree Expand file tree Collapse file tree 3 files changed +48
-21
lines changed
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions Expand file tree Collapse file tree 3 files changed +48
-21
lines changed Original file line number Diff line number Diff line change 8
8
9
9
#include < ATen/ATen.h>
10
10
#include < ATen/cuda/CUDAContext.h>
11
+ #include < c10/cuda/CUDAGuard.h>
11
12
// clang-format on
12
13
13
14
#include " f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh"
@@ -28,12 +29,21 @@ at::Tensor dispatch_fp8_rowwise_kernel(
28
29
int M = size_to_dim_ (XQ.dim () - 1 , XQ.sizes ());
29
30
int N = size_to_dim_ (WQ.dim () - 1 , WQ.sizes ());
30
31
int K = XQ.size (-1 );
31
-
32
- int arch = 9 ;
33
- cudaDeviceProp prop;
34
- cudaGetDeviceProperties (&prop, 0 );
35
- if (prop.major >= 10 ) {
36
- arch = 10 ;
32
+ static int arch = -1 ;
33
+ // Cache arch in a function-local static so the expensive cudaGetDeviceProperties query (device 0) runs only on the first call.
34
+ if (arch < 0 ) {
35
+ cudaDeviceProp prop;
36
+ cudaGetDeviceProperties (&prop, 0 );
37
+ if (prop.major >= 10 ) {
38
+ arch = 10 ;
39
+ int runtimeVersion;
40
+ C10_CUDA_CHECK (cudaRuntimeGetVersion (&runtimeVersion));
41
+ TORCH_CHECK (
42
+ runtimeVersion >= 12080 ,
43
+ " FP8 GEMM on sm100a or above requires cuda >= 12.8" );
44
+ } else {
45
+ arch = 9 ;
46
+ }
37
47
}
38
48
39
49
// Use shape heuristics to dispatch to optimized kernel configuration.
Original file line number Diff line number Diff line change 6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
+ #include < c10/cuda/CUDAGuard.h>
9
10
#include < cute/tensor.hpp>
10
11
#include " f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_manifest.cuh"
11
12
@@ -29,11 +30,21 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
29
30
bool use_fast_accum = true ,
30
31
std::optional<at::Tensor> bias = std::nullopt,
31
32
std::optional<at::Tensor> output = std::nullopt) {
32
- int arch = 9 ;
33
- cudaDeviceProp prop;
34
- cudaGetDeviceProperties (&prop, 0 );
35
- if (prop.major >= 10 ) {
36
- arch = 10 ;
33
+ static int arch = -1 ;
34
+ // Cache arch in a function-local static so the expensive cudaGetDeviceProperties query (device 0) runs only on the first call.
35
+ if (arch < 0 ) {
36
+ cudaDeviceProp prop;
37
+ cudaGetDeviceProperties (&prop, 0 );
38
+ if (prop.major >= 10 ) {
39
+ arch = 10 ;
40
+ int runtimeVersion;
41
+ C10_CUDA_CHECK (cudaRuntimeGetVersion (&runtimeVersion));
42
+ TORCH_CHECK (
43
+ runtimeVersion >= 12080 ,
44
+ " FP8 batched GEMM on sm100a or above requires cuda >= 12.8" );
45
+ } else {
46
+ arch = 9 ;
47
+ }
37
48
}
38
49
39
50
TORCH_CHECK (
Original file line number Diff line number Diff line change 8
8
9
9
#include < ATen/ATen.h>
10
10
#include < ATen/cuda/CUDAContext.h>
11
+ #include < c10/cuda/CUDAGuard.h>
11
12
// clang-format on
12
13
13
14
#include " f8f8bf16_rowwise_grouped/f8f8bf16_rowwise_grouped_manifest.cuh"
@@ -30,16 +31,21 @@ at::Tensor dispatch_fp8_grouped_kernel(
30
31
at::Tensor output,
31
32
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
32
33
std::optional<at::Tensor> M_sizes = std::nullopt) {
33
- int arch = 9 ;
34
- cudaDeviceProp prop;
35
- cudaGetDeviceProperties (&prop, 0 );
36
- if (prop.major >= 10 ) {
37
- arch = 10 ;
38
- int runtimeVersion;
39
- cudaRuntimeGetVersion (&runtimeVersion);
40
- TORCH_CHECK (
41
- runtimeVersion >= 12080 ,
42
- " FP8 grouped GEMM on blackwell sm100a requires cuda >= 12.8" );
34
+ static int arch = -1 ;
35
+ // Cache arch in a function-local static so the expensive cudaGetDeviceProperties query (device 0) runs only on the first call.
36
+ if (arch < 0 ) {
37
+ cudaDeviceProp prop;
38
+ cudaGetDeviceProperties (&prop, 0 );
39
+ if (prop.major >= 10 ) {
40
+ arch = 10 ;
41
+ int runtimeVersion;
42
+ C10_CUDA_CHECK (cudaRuntimeGetVersion (&runtimeVersion));
43
+ TORCH_CHECK (
44
+ runtimeVersion >= 12080 ,
45
+ " FP8 grouped GEMM on sm100a or above requires cuda >= 12.8" );
46
+ } else {
47
+ arch = 9 ;
48
+ }
43
49
}
44
50
45
51
// Use heuristics to pick the best kernel implementation.
You can’t perform that action at this time.
0 commit comments