|
1 | 1 | use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D}; |
2 | 2 | use float8::F8E4M3; |
3 | 3 |
|
4 | | -/// Check whether fp8 (F8E4M3) is supported on the given device. |
5 | | -/// fp8 requires CUDA compute capability >= 8.0 and is not supported on Metal. |
6 | | -fn supports_fp8(device: &Device) -> bool { |
7 | | - if device.is_metal() { |
8 | | - return false; |
9 | | - } |
| 4 | +/// Check whether BF16/F8E4M3 CUDA kernels (fill, affine, reduce, etc.) are available. |
| 5 | +/// These kernels require `__CUDA_ARCH__ >= 800` (Ampere+) and are not compiled for older GPUs. |
| 6 | +/// Metal does not support F8E4M3 at all; BF16 fill works on Metal. |
| 7 | +fn supports_sm80_dtypes(device: &Device) -> bool { |
10 | 8 | if device.is_cuda() { |
11 | | - // CUDA_COMPUTE_CAP is set at build time and reflects the target SM arch. |
12 | 9 | if let Ok(cap) = std::env::var("CUDA_COMPUTE_CAP") { |
13 | 10 | if let Ok(cap) = cap.parse::<u32>() { |
14 | 11 | return cap >= 80; |
15 | 12 | } |
16 | 13 | } |
17 | | - // If CUDA_COMPUTE_CAP is unset, assume SM >= 80 (optimistic default). |
18 | 14 | } |
19 | 15 | true |
20 | 16 | } |
@@ -65,23 +61,25 @@ fn ones(device: &Device) -> Result<()> { |
65 | 61 | ] |
66 | 62 | ], |
67 | 63 | ); |
68 | | - assert_eq!( |
69 | | - Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?, |
70 | | - [ |
| 64 | + if supports_sm80_dtypes(device) { |
| 65 | + assert_eq!( |
| 66 | + Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?, |
71 | 67 | [ |
72 | | - half::bf16::from_f32(1.0), |
73 | | - half::bf16::from_f32(1.0), |
74 | | - half::bf16::from_f32(1.0) |
| 68 | + [ |
| 69 | + half::bf16::from_f32(1.0), |
| 70 | + half::bf16::from_f32(1.0), |
| 71 | + half::bf16::from_f32(1.0) |
| 72 | + ], |
| 73 | + [ |
| 74 | + half::bf16::from_f32(1.0), |
| 75 | + half::bf16::from_f32(1.0), |
| 76 | + half::bf16::from_f32(1.0) |
| 77 | + ] |
75 | 78 | ], |
76 | | - [ |
77 | | - half::bf16::from_f32(1.0), |
78 | | - half::bf16::from_f32(1.0), |
79 | | - half::bf16::from_f32(1.0) |
80 | | - ] |
81 | | - ], |
82 | | - ); |
| 79 | + ); |
| 80 | + } |
83 | 81 |
|
84 | | - if supports_fp8(device) { |
| 82 | + if supports_sm80_dtypes(device) && !device.is_metal() { |
85 | 83 | assert_eq!( |
86 | 84 | Tensor::ones((2, 3), DType::F8E4M3, device)?.to_vec2::<F8E4M3>()?, |
87 | 85 | [ |
@@ -147,7 +145,7 @@ fn arange(device: &Device) -> Result<()> { |
147 | 145 | [5, 4, 3, 2, 1], |
148 | 146 | ); |
149 | 147 |
|
150 | | - if supports_fp8(device) { |
| 148 | + if supports_sm80_dtypes(device) && !device.is_metal() { |
151 | 149 | assert_eq!( |
152 | 150 | Tensor::arange_step( |
153 | 151 | F8E4M3::from_f32(0.), |
|
0 commit comments