[Bug] storage_rewrite pass fails when using T.alloc_shared #481

@fengz72

Scenario

V-Core computation that uses T.alloc_shared (instead of T.alloc_ub) for memory allocation.

Error Message

Traceback (most recent call last):
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/rope/rope_setvalue_sync.py", line 187, in <module>
    tilelang_apply_rope_partial_in_place(x_tl, sin, cos)
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/rope/rope_setvalue_sync.py", line 118, in tilelang_apply_rope_partial_in_place
    kernel = rope_kernel_in_place(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/jit/__init__.py", line 218, in wrapper
    kernel_result = compile(
                    ^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/jit/__init__.py", line 82, in compile
    return cached(
           ^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/cache/__init__.py", line 35, in cached
    return _kernel_cache_instance.cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/cache/kernel_cache.py", line 147, in cached
    return JITKernel(
           ^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/jit/kernel.py", line 118, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx, workspace_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/jit/kernel.py", line 225, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/engine/lower.py", line 231, in lower
    mod = OptimizeForTarget(mod, target, platform)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/engine/phase.py", line 93, in OptimizeForTarget
    mod = tilelang.transform.AscendStorageRewrite(is_npu=check_npu_availability())(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/_ffi/base.py", line 465, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::allocator<char>, tvm::runtime::TVMArgs const&)
  37: tvm::transform::Pass::operator()(tvm::IRModule) const
  36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  35: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  34: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform20AscendStorageRewriteEbEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  33: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AttrStmtNode const*)
  32: void tvm::tir::LinearAccessPatternFinder::VisitNewScope<tvm::tir::AttrStmtNode>(tvm::tir::AttrStmtNode const*)
  31: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  30: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  29: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  28: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  27: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  26: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  25: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  24: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  23: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  22: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  21: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  20: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  19: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  18: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  17: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  16: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  15: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  14: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  13: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  12: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  11: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  10: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  9: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  8: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  7: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::AllocateNode const*)
  6: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
  5: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
  4: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
  3: tvm::tir::LinearAccessPatternFinder::VisitStmt_(tvm::tir::EvaluateNode const*)
  2: tvm::tir::ExprVisitor::VisitExpr_(tvm::tir::CallNode const*)
  1: tvm::tir::ExprVisitor::VisitExpr_(tvm::tir::CallNode const*)
  0: non-virtual thunk to tvm::tir::LinearAccessPatternFinder::VisitExpr_(tvm::tir::VarNode const*)
  File "/data/h00909914/workspace/tilelang/tilelang-ascend/3rdparty/tvm/src/tir/transforms/storage_rewrite.cc", line 198
InternalError: Check failed: it->second.level < scope_.size() (3 vs. 2) :  buf=mask_ub

Reproduction

Modify examples/pos_embedding/rms_rope_fused.py by changing T.alloc_ub to T.alloc_shared, then run the script.
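The substitution can be sketched as a small text rewrite, in case it helps to apply it to a scratch copy rather than editing the example by hand. The helper name `swap_alloc_calls` and the sample line are made up for illustration; only the two allocator names come from this issue.

```python
import re


def swap_alloc_calls(source: str) -> str:
    """Replace every T.alloc_ub( call with T.alloc_shared( in the given source text."""
    # \b keeps the match anchored at the start of the "T" identifier.
    return re.sub(r"\bT\.alloc_ub\(", "T.alloc_shared(", source)


# Hypothetical sample line; the real script's arguments may differ.
sample = "buf = T.alloc_ub(shape, dtype)"
patched = swap_alloc_calls(sample)
print(patched)
```

Applying the function to the contents of examples/pos_embedding/rms_rope_fused.py and running the patched copy should reach the same AscendStorageRewrite step, assuming nothing else in the script depends on the allocation scope.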

Possible Cause

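The resource_scope in the generated IR has incorrect scope nesting (it appears incorrectly indented in the printed IR). The failed check, it->second.level < scope_.size() (3 vs. 2), is consistent with this: mask_ub was recorded at nesting level 3 but is referenced while only 2 scopes are open.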
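To illustrate what kind of nesting would trip this check, here is a toy model of a scope stack that records the nesting level at which each buffer is allocated. It is only loosely modeled on the check quoted in the error message; the class `ScopeTracker` and its methods are hypothetical and are not the code in storage_rewrite.cc.

```python
class ScopeTracker:
    """Toy scope stack: each allocation remembers how many scopes were open."""

    def __init__(self):
        self.scope = []        # one list of touched buffers per open scope
        self.alloc_level = {}  # buffer name -> number of scopes open at allocation

    def enter(self):
        self.scope.append([])

    def leave(self):
        self.scope.pop()

    def allocate(self, name):
        self.alloc_level[name] = len(self.scope)

    def touch(self, name):
        level = self.alloc_level[name]
        # Assumed invariant: a buffer is only touched inside a scope nested
        # below the point where it was allocated.
        if level >= len(self.scope):
            raise RuntimeError(
                f"Check failed: level < scope size ({level} vs. {len(self.scope)}) : buf={name}"
            )
        self.scope[level].append(name)


def well_nested():
    t = ScopeTracker()
    t.enter(); t.enter()      # two enclosing scopes
    t.allocate("mask_ub")     # recorded at level 2
    t.enter()                 # statement that uses the buffer
    t.touch("mask_ub")        # 2 < 3, accepted
    return True


def mis_nested():
    t = ScopeTracker()
    t.enter(); t.enter(); t.enter()
    t.allocate("mask_ub")     # recorded at level 3
    t.leave()                 # the enclosing scope closes too early
    try:
        t.touch("mask_ub")    # 3 vs. 2, rejected
    except RuntimeError as err:
        return str(err)
    return None


ok = well_nested()
message = mis_nested()
print(ok, message)
```

In this model the failure appears whenever a buffer is referenced after the scope that enclosed its allocation has already closed, which is the pattern a misplaced resource_scope would produce.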
