Skip to content

Commit 8315b19

Browse files
committed
Guard BF16/F8E4M3 fill tests for CUDA compute cap < 80
The fill_bf16 and fill_f8_e4m3 kernels require __CUDA_ARCH__ >= 800. On T4 GPUs (SM 7.5) these symbols don't exist, causing CUDA_ERROR_NOT_FOUND. Add supports_sm80_dtypes() helper that checks CUDA_COMPUTE_CAP env var and guards both BF16 and F8E4M3 tests.
1 parent 2bc5f3c commit 8315b19

1 file changed

Lines changed: 21 additions & 23 deletions

File tree

candle-core/tests/tensor_tests.rs

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
22
use float8::F8E4M3;
33

4-
/// Check whether fp8 (F8E4M3) is supported on the given device.
5-
/// fp8 requires CUDA compute capability >= 8.0 and is not supported on Metal.
6-
fn supports_fp8(device: &Device) -> bool {
7-
if device.is_metal() {
8-
return false;
9-
}
4+
/// Check whether BF16/F8E4M3 CUDA kernels (fill, affine, reduce, etc.) are available.
5+
/// These kernels require `__CUDA_ARCH__ >= 800` (Ampere+) and are not compiled for older GPUs.
6+
/// Metal does not support F8E4M3 at all; BF16 fill works on Metal.
7+
fn supports_sm80_dtypes(device: &Device) -> bool {
108
if device.is_cuda() {
11-
// CUDA_COMPUTE_CAP is set at build time and reflects the target SM arch.
129
if let Ok(cap) = std::env::var("CUDA_COMPUTE_CAP") {
1310
if let Ok(cap) = cap.parse::<u32>() {
1411
return cap >= 80;
1512
}
1613
}
17-
// If CUDA_COMPUTE_CAP is unset, assume SM >= 80 (optimistic default).
1814
}
1915
true
2016
}
@@ -65,23 +61,25 @@ fn ones(device: &Device) -> Result<()> {
6561
]
6662
],
6763
);
68-
assert_eq!(
69-
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
70-
[
64+
if supports_sm80_dtypes(device) {
65+
assert_eq!(
66+
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
7167
[
72-
half::bf16::from_f32(1.0),
73-
half::bf16::from_f32(1.0),
74-
half::bf16::from_f32(1.0)
68+
[
69+
half::bf16::from_f32(1.0),
70+
half::bf16::from_f32(1.0),
71+
half::bf16::from_f32(1.0)
72+
],
73+
[
74+
half::bf16::from_f32(1.0),
75+
half::bf16::from_f32(1.0),
76+
half::bf16::from_f32(1.0)
77+
]
7578
],
76-
[
77-
half::bf16::from_f32(1.0),
78-
half::bf16::from_f32(1.0),
79-
half::bf16::from_f32(1.0)
80-
]
81-
],
82-
);
79+
);
80+
}
8381

84-
if supports_fp8(device) {
82+
if supports_sm80_dtypes(device) && !device.is_metal() {
8583
assert_eq!(
8684
Tensor::ones((2, 3), DType::F8E4M3, device)?.to_vec2::<F8E4M3>()?,
8785
[
@@ -147,7 +145,7 @@ fn arange(device: &Device) -> Result<()> {
147145
[5, 4, 3, 2, 1],
148146
);
149147

150-
if supports_fp8(device) {
148+
if supports_sm80_dtypes(device) && !device.is_metal() {
151149
assert_eq!(
152150
Tensor::arange_step(
153151
F8E4M3::from_f32(0.),

0 commit comments

Comments
 (0)