Skip to content

[Issue]: Cached kernel is used for different shapes and resulting in incorrect outputs for vecAdd #317

@rahulbatra85

Description

@rahulbatra85

Problem Description

A cached kernel is reused for inputs of different shapes, resulting in incorrect outputs for vecAdd. Turning off kernel caching produces the correct result. Also, dumping the IR shows incorrect shapes being used when the input shape differs.

Operating System

Ubuntu 24.04.3 LTS (Noble Numbat)

CPU

AMD EPYC 9575F 64-Core Processor

GPU

AMD Instinct MI355, gfx950

ROCm Version

ROCm 7.1.0

ROCm Component

No response

Steps to Reproduce

import sys
import os
import numpy as np
import pytest

import torch

import flydsl.compiler as flyc
import flydsl.expr as fx

import logging

logger = logging.getLogger("flydsl")

def checkAllclose(
    a, b, rtol=1e-2, atol=1e-2, tol_err_ratio=0.05, msg="", printNum=8, printLog=True
):
    """Element-wise comparison of two tensors via ``torch.isclose``.

    Returns 0 when every element is within tolerance, otherwise the
    fraction of mismatching elements (float in (0, 1]).  When ``printLog``
    is True a summary of the mismatches is logged; the wording depends on
    whether the mismatch fraction exceeds ``tol_err_ratio``.

    Args:
        a, b: tensors of the same shape (``a`` is the result under test,
            ``b`` the reference).
        rtol, atol: tolerances forwarded to ``torch.isclose``.
        tol_err_ratio: mismatch fraction above which the log line is
            reported as "failed" rather than "warning".
        msg: prefix prepended to every log line.
        printNum: maximum number of mismatching elements to print.
        printLog: when False, skip all logging and only return the
            mismatch fraction.
    """
    isClose = torch.isclose(a, b, rtol=rtol, atol=atol)

    if isClose.all():
        if printLog:
            logger.info(f"{msg}[checkAllclose {atol=} {rtol=} \033[32mpassed~\033[0m]")
        return 0

    # Mismatch statistics are cheap and device-agnostic; compute them once
    # (the original duplicated this in both the try and except branches).
    mask = ~isClose
    num = mask.sum()
    percent = (num / a.numel()).item()
    if not printLog:
        return percent

    try:
        a_msked = a[mask]
        b_msked = b[mask]
    except RuntimeError:
        # Boolean indexing can fail on some devices; retry with the mask
        # moved to the CPU (same fallback the original took).
        mask = ~isClose.to("cpu")
        num = mask.sum()
        percent = (num / a.numel()).item()
        a_msked = a[mask]
        b_msked = b[mask]
    printNum = min(printNum, num)
    delta = (a_msked - b_msked).abs()

    if percent > tol_err_ratio:
        logger.info(
            f"""{msg}[checkAllclose {atol=} {rtol=} \033[31mfailed!\033[0m]
    a    : {a.shape}
           {a_msked[:printNum]}
    b    : {b.shape}
           {b_msked[:printNum]}
    delta:
           {delta[:printNum]}"""
        )
    else:
        logger.info(
            f"""{msg}[checkAllclose {atol=} {rtol=} \033[33mwarning!\033[0m] a and b results are not all close"""
        )
    logger.info(
        f"-->max abs delta:{delta.max()}, delta details: {percent:.1%} ({num} of {a.numel()}) elements"
    )
    return percent


@flyc.kernel
def vecAddKernel(
    A: fx.Tensor,
    B: fx.Tensor,
    C: fx.Tensor,
    block_dim: fx.Constexpr[int],
    vec_width: fx.Constexpr[int],
):
    """Vectorized element-wise add: C = A + B.

    Each thread block owns one contiguous tile of ``block_dim * vec_width``
    elements; each thread copies ``vec_width`` consecutive f32 elements of A
    and B into registers, adds them, and writes the result back to C.
    NOTE(review): the DSL semantics below are inferred from the flydsl call
    names — confirm against the flydsl documentation.
    """
    bid = fx.block_idx.x
    tid = fx.thread_idx.x

    # Number of elements processed by one thread block.
    tile_elems = block_dim * vec_width

    # Partition the flat tensors into per-block tiles of tile_elems elements.
    tA = fx.logical_divide(A, fx.make_layout(tile_elems, 1))
    tB = fx.logical_divide(B, fx.make_layout(tile_elems, 1))
    tC = fx.logical_divide(C, fx.make_layout(tile_elems, 1))

    # Select this block's tile.
    tA = fx.slice(tA, (None, bid))
    tB = fx.slice(tB, (None, bid))
    tC = fx.slice(tC, (None, bid))

    # Sub-partition the tile into per-thread vectors of vec_width elements.
    tA = fx.logical_divide(tA, fx.make_layout(vec_width, 1))
    tB = fx.logical_divide(tB, fx.make_layout(vec_width, 1))
    tC = fx.logical_divide(tC, fx.make_layout(vec_width, 1))

    # Copy atom moving vec_width 32-bit lanes per call (f32 elements).
    copy_bits = vec_width * 32
    RABMemRefTy = fx.MemRefType.get(
        fx.T.f32(), fx.LayoutType.get(vec_width, 1), fx.AddressSpace.Register
    )
    copyAtom = fx.make_copy_atom(fx.UniversalCopy(copy_bits), fx.Float32)

    # Register-space staging buffers for this thread's vector of A, B, C.
    rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1))
    rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1))
    rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1))

    # Load this thread's slices of A and B into registers.
    fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA)
    fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB)

    # Vector add in registers, then store the result vector.
    vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB))
    fx.memref_store_vec(vC, rC)

    # Write the result back to this thread's slice of C.
    fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid)))


@flyc.jit
def vecAdd(
    A: fx.Tensor,
    B: fx.Tensor,
    # NOTE(review): unlike A and B, C carries no fx.Tensor annotation.
    # Annotations look load-bearing for the JIT's typing/cache key —
    # confirm whether this omission contributes to the stale-cache bug.
    C,
    n: fx.Int32,
    block_dim: fx.Constexpr[int],
    vec_width: fx.Constexpr[int],
    stream: fx.Stream = fx.Stream(None),
):
    """Launch vecAddKernel over n elements: C = A + B.

    Grid size is the ceiling of n / (block_dim * vec_width) so the whole
    input is covered; block size is block_dim threads.
    """
    tile_elems = block_dim * vec_width
    # Ceiling division: number of blocks needed to cover n elements.
    grid_x = (n + tile_elems - 1) // tile_elems
    vecAddKernel(A, B, C, block_dim, vec_width).launch(
        grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream
    )


def checkCorrectness(N, dtype=torch.bfloat16, vec_width=8):
    """Run vecAdd on N ones-filled elements and verify C == A + B.

    Note: the value printed as "max error" is actually the mismatch
    fraction returned by checkAllclose.
    """
    dev_a = torch.ones(N, 1, dtype=dtype).cuda()
    dev_b = torch.ones(N, 1, dtype=dtype).cuda()
    dev_c = torch.zeros(N, 1, dtype=dtype).cuda()

    stream = torch.cuda.Stream()
    THREADS_PER_BLOCK = 256
    VEC_WIDTH = vec_width
    TILE_ELEMS = THREADS_PER_BLOCK * VEC_WIDTH  # not used below; kept for reference

    vecAdd(dev_a, dev_b, dev_c, N, THREADS_PER_BLOCK, VEC_WIDTH, stream=stream)
    torch.cuda.synchronize()

    error = checkAllclose(dev_c, dev_a + dev_b)
    print(f"  Correctness: max error = {error:.2e}")
    if error != 0:
        print(" Correctness check FAILED!!!")


# Exercise three problem sizes; after the first call the later ones hit
# the runtime kernel cache, which is what triggers the reported bug.
for N in (512, 4096 * 4096, 4096 * 16384):
    checkCorrectness(N, torch.float32)

Running with kernel caching enabled (the default) fails the correctness check:
python3 test_vec_add.py
Correctness: max error = 0.00e+00
Correctness: max error = 1.00e+00
Correctness check FAILED!!!
Correctness: max error = 1.00e+00
Correctness check FAILED!!!

FLYDSL_RUNTIME_ENABLE_CACHE=0 python3 test_vec_add.py
Correctness: max error = 0.00e+00
Correctness: max error = 0.00e+00
Correctness: max error = 0.00e+00

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions