Skip to content

Commit c24f8a8

Browse files
liuhao2638 authored and claude committed
fix: skip llm_int8_linear on V100, fix sparse_attention columns shape
- test_ai_quantized_linear: Add _is_ampere_or_above() check to skip TestLlmInt8Linear tests on GPUs with compute capability < 8.0 (CI V100 has sm_70, where cublasLtMatmul returns CUBLAS_STATUS_NOT_SUPPORTED)

- test_ai_sparse_attention: Fix columns tensor shape from total_nnz (B*H*S*nnz_per_row) to per_head_nnz (S*nnz_per_row), resolving a cuSPARSE dimension mismatch (nnz > matrix_size) and segfault

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f05eb90 commit c24f8a8

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

test/ai_edited_test/test_ai_quantized_linear.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,18 @@ def test_weight_only_linear_2d_input(self):
349349
self.skipTest(f"Unsupported arch or CUDA error: {e}")
350350

351351

352+
def _is_ampere_or_above():
353+
"""Check if GPU compute capability >= 8.0 (Ampere+).
354+
llm_int8_linear requires Ampere or newer architecture."""
355+
if not paddle.is_compiled_with_cuda():
356+
return False
357+
try:
358+
arch = _get_arch_info()
359+
return arch >= 80
360+
except (ValueError, RuntimeError):
361+
return False
362+
363+
352364
class TestLlmInt8Linear(unittest.TestCase):
353365
"""Test llm_int8_linear function.
354366
测试 llm_int8_linear 函数。"""
@@ -357,7 +369,8 @@ def setUp(self):
357369
paddle.disable_static()
358370

359371
@unittest.skipIf(
360-
not paddle.is_compiled_with_cuda(), "CUDA required for llm_int8_linear"
372+
not _is_ampere_or_above(),
373+
"llm_int8_linear requires Ampere+ (sm_80), skipped on CI V100 (sm_70)",
361374
)
362375
def test_llm_int8_linear_basic(self):
363376
"""Test basic llm_int8_linear without bias.
@@ -373,7 +386,8 @@ def test_llm_int8_linear_basic(self):
373386
self.skipTest(f"CUDA error: {e}")
374387

375388
@unittest.skipIf(
376-
not paddle.is_compiled_with_cuda(), "CUDA required for llm_int8_linear"
389+
not _is_ampere_or_above(),
390+
"llm_int8_linear requires Ampere+ (sm_80), skipped on CI V100 (sm_70)",
377391
)
378392
def test_llm_int8_linear_with_bias(self):
379393
"""Test llm_int8_linear with bias.
@@ -391,7 +405,8 @@ def test_llm_int8_linear_with_bias(self):
391405
self.skipTest(f"CUDA error: {e}")
392406

393407
@unittest.skipIf(
394-
not paddle.is_compiled_with_cuda(), "CUDA required for llm_int8_linear"
408+
not _is_ampere_or_above(),
409+
"llm_int8_linear requires Ampere+ (sm_80), skipped on CI V100 (sm_70)",
395410
)
396411
def test_llm_int8_linear_different_threshold(self):
397412
"""Test llm_int8_linear with different threshold.
@@ -406,7 +421,8 @@ def test_llm_int8_linear_different_threshold(self):
406421
self.skipTest(f"CUDA error: {e}")
407422

408423
@unittest.skipIf(
409-
not paddle.is_compiled_with_cuda(), "CUDA required for llm_int8_linear"
424+
not _is_ampere_or_above(),
425+
"llm_int8_linear requires Ampere+ (sm_80), skipped on CI V100 (sm_70)",
410426
)
411427
def test_llm_int8_linear_high_threshold(self):
412428
"""Test llm_int8_linear with high threshold (fewer outliers).

test/ai_edited_test/test_ai_sparse_attention.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_sparse_attention_output_shape(self):
8282
)
8383
# Each position attends to 2 positions
8484
nnz_per_row = 2
85-
total_nnz = batch_size * num_heads * seq_len * nnz_per_row
85+
nnz_per_head = seq_len * nnz_per_row
8686
offset = paddle.zeros(
8787
[batch_size, num_heads, seq_len + 1], dtype="int32"
8888
)
@@ -91,7 +91,7 @@ def test_sparse_attention_output_shape(self):
9191
for s in range(seq_len):
9292
offset[b, h, s + 1] = offset[b, h, s] + nnz_per_row
9393
columns = paddle.zeros(
94-
[batch_size, num_heads, total_nnz], dtype="int32"
94+
[batch_size, num_heads, nnz_per_head], dtype="int32"
9595
)
9696
# Each position attends to itself and the next position
9797
for b in range(batch_size):
@@ -263,11 +263,11 @@ def test_sparse_attention_multi_head(self):
263263
value = paddle.randn([batch, heads, seq, dim], dtype="float32")
264264
# Dense pattern: each row attends to all 4 positions
265265
nnz_per_row = 4
266+
nnz_per_head = seq * nnz_per_row
266267
offset = paddle.zeros([batch, heads, seq + 1], dtype="int32")
267268
for s in range(seq):
268269
offset[0, :, s + 1] = offset[0, :, s] + nnz_per_row
269-
total_nnz = batch * heads * seq * nnz_per_row
270-
columns = paddle.zeros([batch, heads, total_nnz], dtype="int32")
270+
columns = paddle.zeros([batch, heads, nnz_per_head], dtype="int32")
271271
for h in range(heads):
272272
for s in range(seq):
273273
base = offset[0, h, s].item()
@@ -323,11 +323,11 @@ def test_sparse_attention_different_head_dim(self):
323323
key = paddle.randn([1, 2, 4, head_dim], dtype="float32")
324324
value = paddle.randn([1, 2, 4, head_dim], dtype="float32")
325325
nnz_per_row = 4
326+
nnz_per_head = 4 * nnz_per_row
326327
offset = paddle.zeros([1, 2, 5], dtype="int32")
327328
for s in range(4):
328329
offset[0, :, s + 1] = offset[0, :, s] + nnz_per_row
329-
total_nnz = 1 * 2 * 4 * nnz_per_row
330-
columns = paddle.zeros([1, 2, total_nnz], dtype="int32")
330+
columns = paddle.zeros([1, 2, nnz_per_head], dtype="int32")
331331
for h in range(2):
332332
for s in range(4):
333333
base = offset[0, h, s].item()

0 commit comments

Comments (0)