The cached kernel is reused for different shapes, resulting in incorrect outputs from vecAdd. Turning off kernel caching produces the correct result. Also, dumping the IR shows stale shapes being used even when the input shape is different.
import sys
import os
import numpy as np
import pytest
import torch
import flydsl.compiler as flyc
import flydsl.expr as fx
import logging
logger = logging.getLogger("flydsl")
def checkAllclose(
    a, b, rtol=1e-2, atol=1e-2, tol_err_ratio=0.05, msg="", printNum=8, printLog=True
):
    """Compare tensors element-wise with torch.isclose and report mismatches.

    Returns 0 when every element of ``a`` is close to ``b``; otherwise the
    fraction of mismatching elements (a float in (0, 1]).  When ``printLog``
    is True the outcome is logged: a mismatch fraction above
    ``tol_err_ratio`` logs up to ``printNum`` offending elements from each
    tensor along with their absolute deltas, otherwise a warning is logged.
    """
    isClose = torch.isclose(a, b, rtol=rtol, atol=atol)
    if isClose.all():
        if printLog:
            logger.info(f"{msg}[checkAllclose {atol=} {rtol=} \033[32mpassed~\033[0m]")
        return 0
    mask = ~isClose
    try:
        a_msked = a[mask]
        b_msked = b[mask]
    except RuntimeError:
        # Masking on the original device can fail (e.g. out of GPU memory).
        # Retry on CPU, moving *all* tensors so the mask and the data share
        # a device — the previous fallback moved only the mask to CPU and
        # then indexed the still-on-GPU tensors with it.
        a, b, mask = a.cpu(), b.cpu(), mask.cpu()
        a_msked = a[mask]
        b_msked = b[mask]
    num = mask.sum()
    percent = (num / a.numel()).item()
    if not printLog:
        return percent
    printNum = min(printNum, int(num))  # never print more elements than mismatch
    delta = (a_msked - b_msked).abs()
    if percent > tol_err_ratio:
        logger.info(
            f"""{msg}[checkAllclose {atol=} {rtol=} \033[31mfailed!\033[0m]
    a : {a.shape}
    {a_msked[:printNum]}
    b : {b.shape}
    {b_msked[:printNum]}
    delta:
    {delta[:printNum]}"""
        )
    else:
        logger.info(
            f"""{msg}[checkAllclose {atol=} {rtol=} \033[33mwarning!\033[0m] a and b results are not all close"""
        )
    logger.info(
        f"-->max abs delta:{delta.max()}, delta details: {percent:.1%} ({num} of {a.numel()}) elements"
    )
    return percent
@flyc.kernel
def vecAddKernel(
    A: fx.Tensor,
    B: fx.Tensor,
    C: fx.Tensor,
    block_dim: fx.Constexpr[int],
    vec_width: fx.Constexpr[int],
):
    """Device kernel computing C = A + B element-wise.

    Each thread block covers one tile of block_dim * vec_width elements;
    each thread moves a vec_width-wide chunk through register storage,
    adds the two chunks, and copies the result back out.  block_dim and
    vec_width are compile-time constants (fx.Constexpr), so the kernel is
    specialized per (block_dim, vec_width) pair.

    NOTE(review): copy_bits = vec_width * 32 and fx.T.f32() hard-code
    32-bit (f32) elements — presumably other dtypes are unsupported here;
    confirm before launching with non-f32 tensors.
    """
    bid = fx.block_idx.x
    tid = fx.thread_idx.x
    # Number of elements one thread block is responsible for.
    tile_elems = block_dim * vec_width
    # Partition each tensor into (tile, block) and keep only this block's tile.
    tA = fx.logical_divide(A, fx.make_layout(tile_elems, 1))
    tB = fx.logical_divide(B, fx.make_layout(tile_elems, 1))
    tC = fx.logical_divide(C, fx.make_layout(tile_elems, 1))
    tA = fx.slice(tA, (None, bid))
    tB = fx.slice(tB, (None, bid))
    tC = fx.slice(tC, (None, bid))
    # Sub-partition the tile into vec_width-wide chunks, one per thread.
    tA = fx.logical_divide(tA, fx.make_layout(vec_width, 1))
    tB = fx.logical_divide(tB, fx.make_layout(vec_width, 1))
    tC = fx.logical_divide(tC, fx.make_layout(vec_width, 1))
    # Width of one vectorized copy in bits (assumes 32-bit elements).
    copy_bits = vec_width * 32
    # Per-thread register-space buffers holding vec_width f32 values.
    RABMemRefTy = fx.MemRefType.get(
        fx.T.f32(), fx.LayoutType.get(vec_width, 1), fx.AddressSpace.Register
    )
    copyAtom = fx.make_copy_atom(fx.UniversalCopy(copy_bits), fx.Float32)
    rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1))
    rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1))
    rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1))
    # Load this thread's chunks of A and B into registers.
    fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA)
    fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB)
    # Vectorized add in registers, then store back to this thread's C chunk.
    vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB))
    fx.memref_store_vec(vC, rC)
    fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid)))
@flyc.jit
def vecAdd(
    A: fx.Tensor,
    B: fx.Tensor,
    C,  # output tensor; NOTE(review): unlike A/B it carries no fx.Tensor annotation — confirm whether the jit tracer treats it differently
    n: fx.Int32,
    block_dim: fx.Constexpr[int],
    vec_width: fx.Constexpr[int],
    stream: fx.Stream = fx.Stream(None),
):
    """Host-side launcher: computes C = A + B over n elements on `stream`.

    Sizes the grid so that grid_x tiles of block_dim * vec_width elements
    cover all n elements (ceiling division), then launches vecAddKernel.

    NOTE(review): block_dim and vec_width are the only Constexpr
    parameters; if the kernel cache keys only on them (and not on the
    tensors' shapes), a compiled kernel traced for one n could be reused
    for a different n — consistent with the caching bug reported above.
    Confirm against the flydsl caching implementation.
    """
    tile_elems = block_dim * vec_width
    # Ceiling division: number of blocks needed to cover n elements.
    grid_x = (n + tile_elems - 1) // tile_elems
    vecAddKernel(A, B, C, block_dim, vec_width).launch(
        grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream
    )
def checkCorrectness(N, dtype=torch.bfloat16, vec_width=8):
    """Run vecAdd on N-element all-ones inputs and verify C == A + B.

    Allocates A = B = 1 and C = 0 of shape (N, 1) on the GPU, launches
    vecAdd on a fresh CUDA stream, and checks the result with
    checkAllclose, printing the mismatch ratio.

    NOTE(review): the kernel has no visible tail handling, so N is
    presumably expected to be a multiple of 256 * vec_width — confirm.
    """
    A = torch.ones(N, 1, dtype=dtype).cuda()
    B = torch.ones(N, 1, dtype=dtype).cuda()
    C = torch.zeros(N, 1, dtype=dtype).cuda()
    # NOTE(review): tensors are allocated on the default stream but the
    # kernel runs on this side stream with no cross-stream event before the
    # launch; the device-wide synchronize below only covers the read-back.
    stream = torch.cuda.Stream()
    THREADS_PER_BLOCK = 256
    vecAdd(A, B, C, N, THREADS_PER_BLOCK, vec_width, stream=stream)
    torch.cuda.synchronize()
    error = checkAllclose(C, A + B)
    # checkAllclose returns the fraction of mismatching elements (0 when all
    # match), so label the value as a ratio — the old "max error" label was
    # wrong.
    print(f" Correctness: mismatch ratio = {error:.2e}")
    if error != 0:
        print(" Correctness check FAILED!!!")
# Repro driver: run three sizes back-to-back so the second and third calls
# hit whatever kernel cache the first (N=512) launch populated.
for N in (512, 4096 * 4096, 4096 * 16384):
    checkCorrectness(N, torch.float32)
Problem Description
The cached kernel is reused for different shapes, resulting in incorrect outputs from vecAdd. Turning off kernel caching produces the correct result. Also, dumping the IR shows stale shapes being used even when the input shape is different.
Operating System
Ubuntu 24.04.3 LTS (Noble Numbat)
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD Instinct MI355, gfx950
ROCm Version
ROCm 7.1.0
ROCm Component
No response
Steps to Reproduce
Running with default kernel caching fails the correctness check
python3 test_vec_add.py
Correctness: max error = 0.00e+00
Correctness: max error = 1.00e+00
Correctness check FAILED!!!
Correctness: max error = 1.00e+00
Correctness check FAILED!!!
FLYDSL_RUNTIME_ENABLE_CACHE=0 python3 test_vec_add.py
Correctness: max error = 0.00e+00
Correctness: max error = 0.00e+00
Correctness: max error = 0.00e+00
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response