diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 43dae5cb8da4..205e951592fb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -49,22 +49,22 @@ jobs: run: >- conda build --output-folder=conda/pkg conda/recipe && conda install tvm -c ./conda/pkg - - name: Build iOS RPC - run: | - IOS_VERSION="14.0" - CMAKE_FLAGS="-DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_SYSTEM_NAME=iOS \ - -DCMAKE_SYSTEM_VERSION=${IOS_VERSION} \ - -DCMAKE_OSX_SYSROOT=iphonesimulator \ - -DCMAKE_OSX_ARCHITECTURES=x86_64 \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ - -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON \ - -DUSE_IOS_RPC=ON" - - mkdir build-ios-simulator - cd build-ios-simulator - cmake .. ${CMAKE_FLAGS} - cmake --build . --target ios_rpc +# - name: Build iOS RPC +# run: | +# IOS_VERSION="14.0" +# CMAKE_FLAGS="-DCMAKE_BUILD_TYPE=Release \ +# -DCMAKE_SYSTEM_NAME=iOS \ +# -DCMAKE_SYSTEM_VERSION=${IOS_VERSION} \ +# -DCMAKE_OSX_SYSROOT=iphonesimulator \ +# -DCMAKE_OSX_ARCHITECTURES=x86_64 \ +# -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ +# -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON \ +# -DUSE_IOS_RPC=ON" +# +# mkdir build-ios-simulator +# cd build-ios-simulator +# cmake .. ${CMAKE_FLAGS} +# cmake --build . --target ios_rpc - name: Test shell: bash -l {0} run: >- @@ -108,74 +108,39 @@ jobs: run: >- python -m pytest -v tests/python/all-platform-minimal-test - # Disabled due to https://github.com/apache/tvm/issues/13950 - # Windows-Static-Runtime: - # if: ${{ github.repository == 'apache/tvm' }} - # runs-on: windows-2019 - # steps: - # - uses: actions/checkout@v2 - # with: - # submodules: 'recursive' - # - name: Set up environment - # uses: ./.github/actions/setup - # - name: Build static TVM runtime - # shell: bash -l {0} - # run: | - # tests/scripts/task_config_build_static.sh build - # cd build - # cmake .. -A x64 -DCMAKE_CONFIGURATION_TYPES="Release" - # cmake --build . --config Release --target runtime - - Linux-Static-Runtime: - if: ${{ github.repository == 'apache/tvm' }} - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - name: Set up environment - uses: ./.github/actions/setup - - name: Build static TVM runtime - shell: bash -l {0} - run: | - tests/scripts/task_config_build_static.sh build - cd build - cmake .. - cmake --build . --config Release --target runtime - - Android: - if: ${{ github.repository == 'apache/tvm' }} - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - name: Set up environment - uses: ./.github/actions/setup - - name: Set up java - uses: actions/setup-java@v3 - with: - distribution: 'zulu' - java-version: '11' - - name: Build TVM - shell: bash -l {0} - run: | - mkdir build - cd build - ../tests/scripts/task_config_build_jvm.sh . - cmake .. - make - - name: Build TVM4J - run: | - make jvmpkg - - name: Build android_rpc - working-directory: apps/android_rpc - run: | - set -eux - export PATH="${ANDROID_NDK_LATEST_HOME}:$PATH" - gradle clean build - - name: Upload android_rpc APK - uses: actions/upload-artifact@v4 - with: - name: android_rpc-debug.apk - path: ./apps/android_rpc/app/build/outputs/apk/debug/app-debug.apk + # Android: + # if: ${{ github.repository == 'apache/tvm' }} + # runs-on: ubuntu-22.04 + # steps: + # - uses: actions/checkout@v2 + # with: + # submodules: 'recursive' + # - name: Set up environment + # uses: ./.github/actions/setup + # - name: Set up java + # uses: actions/setup-java@v3 + # with: + # distribution: 'zulu' + # java-version: '11' + # - name: Build TVM + # shell: bash -l {0} + # run: | + # mkdir build + # cd build + # ../tests/scripts/task_config_build_jvm.sh . + # cmake .. + # make + # - name: Build TVM4J + # run: | + # make jvmpkg + # - name: Build android_rpc + # working-directory: apps/android_rpc + # run: | + # set -eux + # export PATH="${ANDROID_NDK_LATEST_HOME}:$PATH" + # gradle clean build + # - name: Upload android_rpc APK + # uses: actions/upload-artifact@v4 + # with: + # name: android_rpc-debug.apk + # path: ./apps/android_rpc/app/build/outputs/apk/debug/app-debug.apk diff --git a/.gitmodules b/.gitmodules index a1187967f77f..e8a48d99c2a2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -31,3 +31,6 @@ [submodule "3rdparty/zlib"] path = 3rdparty/zlib url = https://github.com/madler/zlib.git +[submodule "ffi/3rdparty/dlpack"] + path = ffi/3rdparty/dlpack + url = https://github.com/dmlc/dlpack.git diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index fdef2307917e..bbccc75af117 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit fdef2307917ec2c7cc5becc29fb95d77498484bd +Subproject commit bbccc75af117473f6de81905bd3314775f41636e diff --git a/CMakeLists.txt b/CMakeLists.txt index caad7fb02b1f..b45e5becf364 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -159,7 +159,6 @@ if(MSVC) add_definitions(-D_SCL_SECURE_NO_WARNINGS) add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) add_definitions(-DNOMINMAX) - # regeneration does not work well with msbuild custom rules. set(CMAKE_SUPPRESS_REGENERATION ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") @@ -496,6 +495,9 @@ list(REMOVE_ITEM COMPILER_SRCS ${LIBINFO_FILE}) add_library(tvm_objs OBJECT ${COMPILER_SRCS}) add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS}) add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE}) +target_link_libraries(tvm_objs PUBLIC tvm_ffi_header) +target_link_libraries(tvm_runtime_objs PUBLIC tvm_ffi_header) +target_link_libraries(tvm_libinfo_objs PUBLIC tvm_ffi_header) include(GNUInstallDirs) if(NOT BUILD_DUMMY_LIBTVM) @@ -567,6 +569,9 @@ if(USE_IOS_RPC) add_subdirectory("apps/ios_rpc") endif() +add_subdirectory(ffi) + + if(TVM_DEBUG_WITH_ABI_CHANGE) message(STATUS "Building with debug code that may cause ABI changes...") target_compile_definitions(tvm_objs PRIVATE "TVM_DEBUG_WITH_ABI_CHANGE") @@ -602,6 +607,10 @@ endif() target_link_libraries(tvm PRIVATE ${TVM_RUNTIME_LINKER_LIBS}) target_link_libraries(tvm_runtime PRIVATE ${TVM_RUNTIME_LINKER_LIBS}) +target_link_libraries(tvm PUBLIC tvm_ffi_objs) +target_link_libraries(tvm_runtime PUBLIC tvm_ffi_objs) + + if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) include(FetchContent) FetchContent_Declare(googletest SOURCE_DIR "${USE_HEXAGON_GTEST}") @@ -639,6 +648,7 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") target_link_libraries(tvm_runtime PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS}) target_compile_definitions(tvm_allvisible PUBLIC $) target_compile_definitions(tvm_allvisible PRIVATE $) + target_link_libraries(tvm_allvisible PUBLIC tvm_ffi_objs) endif() # Create the `cpptest` target if we can find GTest. If not, we create dummy @@ -696,10 +706,10 @@ if(NOT DEFINED ENV{CONDA_BUILD}) endif() # Installation rules -install(TARGETS tvm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) -install(TARGETS tvm_runtime EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) +install(TARGETS tvm DESTINATION lib${LIB_SUFFIX}) +install(TARGETS tvm_runtime DESTINATION lib${LIB_SUFFIX}) if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) - install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) + install(TARGETS gtest DESTINATION lib${LIB_SUFFIX}) endif() if (INSTALL_DEV) @@ -734,9 +744,9 @@ string(APPEND PROJECT_CONFIG_CONTENT "include(\"\${CMAKE_CURRENT_LIST_DIR}/${PROJECT_NAME}Targets.cmake\")") file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/temp_config_file.cmake" ${PROJECT_CONFIG_CONTENT}) -install(EXPORT ${PROJECT_NAME}Targets - NAMESPACE ${PROJECT_NAME}:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) +# install(EXPORT ${PROJECT_NAME}Targets +# NAMESPACE ${PROJECT_NAME}:: +# DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) # Create config for find_package() configure_package_config_file( @@ -750,9 +760,9 @@ install( # More target definitions if(MSVC) - target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS) - target_compile_definitions(tvm_libinfo_objs PRIVATE -DTVM_EXPORTS) - target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS) + target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS -DTVM_FFI_EXPORTS) + target_compile_definitions(tvm_libinfo_objs PRIVATE -DTVM_EXPORTS -DTVM_FFI_EXPORTS) + target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS -DTVM_FFI_EXPORTS) endif() set(TVM_IS_DEBUG_BUILD OFF) @@ -774,17 +784,8 @@ if(TVM_IS_DEBUG_BUILD) endif() endif() -# Run dsymutil to generate debugging symbols for backtraces -if(APPLE AND TVM_IS_DEBUG_BUILD) - find_program(DSYMUTIL dsymutil) - mark_as_advanced(DSYMUTIL) - add_custom_command(TARGET tvm - POST_BUILD - COMMAND ${DSYMUTIL} ARGS $ - COMMENT "Running dsymutil" - VERBATIM - ) -endif() +add_dsymutil(tvm) +add_dsymutil(tvm_runtime) if(BUILD_FOR_HEXAGON) # Wrap pthread_create to allow setting custom stack size. diff --git a/apps/android_rpc/app/src/main/jni/Android.mk b/apps/android_rpc/app/src/main/jni/Android.mk index ad9cee9bbdb5..692a3390131d 100644 --- a/apps/android_rpc/app/src/main/jni/Android.mk +++ b/apps/android_rpc/app/src/main/jni/Android.mk @@ -37,7 +37,8 @@ LOCAL_SRC_FILES := org_apache_tvm_native_c_api.cc LOCAL_LDFLAGS := -L$(SYSROOT)/usr/lib/ -llog LOCAL_C_INCLUDES := $(ROOT_PATH)/include \ - $(ROOT_PATH)/3rdparty/dlpack/include \ + $(ROOT_PATH)/ffi/include \ + $(ROOT_PATH)/ffi/3rdparty/dlpack/include \ $(ROOT_PATH)/3rdparty/dmlc-core/include \ $(ROOT_PATH)/3rdparty/OpenCL-Headers diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index fa9fde747892..26085bc366f4 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -32,7 +32,16 @@ * Android logcat. */ #define TVM_LOG_CUSTOMIZE 1 +#define TVM_FFI_USE_LIBBACKTRACE 0 +#include "../ffi/src/ffi/container.cc" +#include "../ffi/src/ffi/dtype.cc" +#include "../ffi/src/ffi/error.cc" +#include "../ffi/src/ffi/function.cc" +#include "../ffi/src/ffi/ndarray.cc" +#include "../ffi/src/ffi/object.cc" +#include "../ffi/src/ffi/testing.cc" +#include "../ffi/src/ffi/traceback.cc" #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/container.cc" #include "../src/runtime/cpu_device_api.cc" diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index ae70401fda19..146b9cf05071 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -126,46 +126,44 @@ RPCEnv::RPCEnv(const std::string& wd) { mkdir(base_.c_str(), 0777); } - TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPath(args[0]); - }); + ffi::Function::SetGlobal( + "tvm.rpc.server.workpath", + ffi::Function::FromUnpacked([this](const std::string& path) { return this->GetPath(path); })); - TVM_REGISTER_GLOBAL("tvm.rpc.server.listdir").set_body([this](TVMArgs args, TVMRetValue* rv) { - std::string dir = this->GetPath(args[0]); - std::ostringstream os; - for (auto d : ListDir(dir)) { - os << d << ","; - } - *rv = os.str(); - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([this](TVMArgs args, TVMRetValue* rv) { - std::string file_name = this->GetPath(args[0]); - file_name = BuildSharedLibrary(file_name); - *rv = Module::LoadFromFile(file_name, ""); - LOG(INFO) << "Load module from " << file_name << " ..."; - }); + ffi::Function::SetGlobal("tvm.rpc.server.listdir", + ffi::Function::FromUnpacked([this](const std::string& path) { + std::string dir = this->GetPath(path); + std::ostringstream os; + for (auto d : ListDir(dir)) { + os << d << ","; + } + return os.str(); + })); - TVM_REGISTER_GLOBAL("tvm.rpc.server.download_linked_module") - .set_body([this](TVMArgs args, TVMRetValue* rv) { - std::string file_name = this->GetPath(args[0]); - file_name = BuildSharedLibrary(file_name); - std::string bin; + ffi::Function::SetGlobal("tvm.rpc.server.load_module", + ffi::Function::FromUnpacked([this](const std::string& path) { + std::string file_name = this->GetPath(path); + file_name = BuildSharedLibrary(file_name); + LOG(INFO) << "Load module from " << file_name << " ..."; + return Module::LoadFromFile(file_name, ""); + })); - std::ifstream fs(file_name, std::ios::in | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << file_name; - fs.seekg(0, std::ios::end); - size_t size = static_cast(fs.tellg()); - fs.seekg(0, std::ios::beg); - bin.resize(size); - fs.read(dmlc::BeginPtr(bin), size); + ffi::Function::SetGlobal("tvm.rpc.server.download_linked_module", + ffi::Function::FromUnpacked([this](const std::string& path) { + std::string file_name = this->GetPath(path); + file_name = BuildSharedLibrary(file_name); + std::string bin; - TVMByteArray binarr; - binarr.data = bin.data(); - binarr.size = bin.length(); - *rv = binarr; - LOG(INFO) << "Send linked module " << file_name << " to client"; - }); + std::ifstream fs(file_name, std::ios::in | std::ios::binary); + ICHECK(!fs.fail()) << "Cannot open " << file_name; + fs.seekg(0, std::ios::end); + size_t size = static_cast(fs.tellg()); + fs.seekg(0, std::ios::beg); + bin.resize(size); + fs.read(dmlc::BeginPtr(bin), size); + LOG(INFO) << "Send linked module " << file_name << " to client"; + return ffi::Bytes(bin); + })); } /*! * \brief GetPath To get the work path from packed function diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 217baf133bf1..c4ee4d35450f 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -398,8 +398,6 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); -}); +TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate); } // namespace runtime } // namespace tvm diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt index f7144835dbe0..4dfba8669b66 100644 --- a/apps/hexagon_api/CMakeLists.txt +++ b/apps/hexagon_api/CMakeLists.txt @@ -46,7 +46,7 @@ ExternalProject_Add(x86_tvm_runtime_rpc "-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}" "-DUSE_HEXAGON_TOOLCHAIN=${USE_HEXAGON_TOOLCHAIN}" "-DCMAKE_CXX_STANDARD=17" - "-DUSE_LIBBACKTRACE=OFF" + "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DUSE_RPC=ON" "-DUSE_CPP_RPC=ON" "-DUSE_HEXAGON=ON" @@ -80,7 +80,7 @@ ExternalProject_Add(android_tvm_runtime_rpc "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DCMAKE_CXX_STANDARD=17" - "-DUSE_LIBBACKTRACE=OFF" + "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DUSE_RPC=ON" "-DUSE_CPP_RPC=ON" "-DUSE_HEXAGON=ON" @@ -132,7 +132,7 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc "-DUSE_HEXAGON_EXTERNAL_LIBS=${USE_HEXAGON_EXTERNAL_LIBS}" "-DHEXAGON_EXTERNAL_LIBS_SHA=${HEXAGON_EXTERNAL_LIBS_SHA}" "-DCMAKE_CXX_STANDARD=17" - "-DUSE_LIBBACKTRACE=OFF" + "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DUSE_RPC=OFF" "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_RPC=ON" diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt index 78d3cb396cfd..42b5feb03a9b 100644 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -78,7 +78,7 @@ ExternalProject_Add(android_tvm_runtime "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_LIBBACKTRACE=OFF" + "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DUSE_LLVM=OFF" "-DUSE_RPC=OFF" INSTALL_COMMAND "" diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt index e8bd67dde7a2..d17b5186aac5 100644 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -89,7 +89,7 @@ ExternalProject_Add(static_hexagon_tvm_runtime "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_LIBBACKTRACE=OFF" + "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DUSE_LLVM=OFF" "-DUSE_RPC=OFF" "-DUSE_CUSTOM_LOGGING=ON" diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index a3993451c212..5b39ace9ba6c 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -37,7 +37,7 @@ const std::string TensorConfig::dtype_key = "dtype"; // NOLINT(runtime/string) std::string tensor_meta::to_string() const { std::stringstream out; - out << "ndim=" << ndim << ", dtype=" << tvm::runtime::DLDataType2String(dtype) << ", shape="; + out << "ndim=" << ndim << ", dtype=" << tvm::runtime::DLDataTypeToString(dtype) << ", shape="; for (int i = 0; i != ndim; ++i) { out << shape[i]; if (i + 1 < ndim) { @@ -138,7 +138,7 @@ Model::Model(tvm::runtime::Module executor, tvm::runtime::Module module, std::st } const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) { - if (const tvm::runtime::PackedFunc* pf = tvm::runtime::Registry::Get(name)) { + if (auto pf = tvm::ffi::Function::GetGlobal(name)) { return *pf; } return tvm::runtime::PackedFunc(); @@ -151,7 +151,7 @@ const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module, void reset_device_api() { const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon"); - tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api); + tvm::ffi::Function::SetGlobal("device_api.cpu", api, true); } tvm::runtime::Module load_module(const std::string& file_name) { diff --git a/apps/hexagon_launcher/launcher_main.cc b/apps/hexagon_launcher/launcher_main.cc index 1ef3b5d2ff3c..8690996684b2 100644 --- a/apps/hexagon_launcher/launcher_main.cc +++ b/apps/hexagon_launcher/launcher_main.cc @@ -102,7 +102,7 @@ int main(int argc, char* argv[]) { for (int i = 0, e = config.inputs.size(); i != e; ++i) { const TensorConfig& tc = config.inputs[i]; input_meta->ndim = tc.shape.size(); - input_meta->dtype = tvm::runtime::String2DLDataType(tc.dtype); + input_meta->dtype = tvm::ffi::StringToDLDataType(tc.dtype); std::copy(tc.shape.begin(), tc.shape.end(), input_meta->shape); auto* input_data = session.alloc(input_meta->data_size()); @@ -145,7 +145,7 @@ int main(int argc, char* argv[]) { for (int i = 0, e = output_meta->ndim; i != e; ++i) { oc.shape.push_back(output_meta->shape[i]); } - oc.dtype = tvm::runtime::DLDataType2String(output_meta->dtype); + oc.dtype = tvm::runtime::DLDataTypeToString(output_meta->dtype); write_binary_file(oc.file_name, output_data, data_size); output_config.outputs.push_back(std::move(oc)); diff --git a/apps/ios_rpc/tvmrpc/RPCServer.mm b/apps/ios_rpc/tvmrpc/RPCServer.mm index 5dcf8c1eb9f8..284a2cfcd9ee 100644 --- a/apps/ios_rpc/tvmrpc/RPCServer.mm +++ b/apps/ios_rpc/tvmrpc/RPCServer.mm @@ -63,8 +63,8 @@ */ FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, std::string remote_key) { - const PackedFunc* event_handler_factory = Registry::Get("rpc.CreateEventDrivenServer"); - ICHECK(event_handler_factory != nullptr) + auto event_handler_factory = tvm::ffi::Function::GetGlobal("rpc.CreateEventDrivenServer"); + ICHECK(event_handler_factory.has_value()) << "You are using tvm_runtime module built without RPC support. " << "Please rebuild it with USE_RPC flag."; diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 4e1d66ed6400..a14bb50e0b2c 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -51,36 +51,37 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body_packed([](TVMArgs args, TVMRetValue* rv) { static const std::string base_ = NSTemporaryDirectory().UTF8String; - const std::string path = args[0]; + const auto path = args[0].cast(); *rv = base_ + "/" + path; }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { - std::string name = args[0]; - std::string fmt = GetFileFormat(name, ""); - NSString* base; - if (fmt == "dylib") { - // only load dylib from frameworks. - NSBundle* bundle = [NSBundle mainBundle]; - base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; - - if (Registry::Get("runtime.module.loadfile_dylib_custom")) { - // Custom dso loader is present. Will use it. - base = NSTemporaryDirectory(); - fmt = "dylib_custom"; - } - } else { - // Load other modules in tempdir. - base = NSTemporaryDirectory(); - } - NSString* path = - [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; - name = [path UTF8String]; - *rv = Module::LoadFromFile(name, fmt); - LOG(INFO) << "Load module from " << name << " ..."; -}); +TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") + .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + auto name = args[0].cast(); + std::string fmt = GetFileFormat(name, ""); + NSString* base; + if (fmt == "dylib") { + // only load dylib from frameworks. + NSBundle* bundle = [NSBundle mainBundle]; + base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; + + if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) { + // Custom dso loader is present. Will use it. + base = NSTemporaryDirectory(); + fmt = "dylib_custom"; + } + } else { + // Load other modules in tempdir. + base = NSTemporaryDirectory(); + } + NSString* path = + [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; + name = [path UTF8String]; + *rv = Module::LoadFromFile(name, fmt); + LOG(INFO) << "Load module from " << name << " ..."; + }); #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 @@ -108,7 +109,7 @@ void Init(const std::string& name) { // Add UnsignedDSOLoader plugin in global registry TVM_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") - .set_body([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); n->Init(args[0]); *rv = CreateModuleFromLibrary(n); diff --git a/cmake/modules/Logging.cmake b/cmake/modules/Logging.cmake index ae4b56118699..2b092f665e48 100644 --- a/cmake/modules/Logging.cmake +++ b/cmake/modules/Logging.cmake @@ -27,90 +27,3 @@ if(USE_CUSTOM_LOGGING) target_compile_definitions(tvm PUBLIC TVM_LOG_CUSTOMIZE=1) target_compile_definitions(tvm_runtime PUBLIC TVM_LOG_CUSTOMIZE=1) endif() - -add_library(libbacktrace STATIC IMPORTED) - -set(LIBBACKTRACE_INCLUDE_DIR NOTFOUND) -set(LIBBACKTRACE_STATIC_LIBRARY NOTFOUND) -set(LIBBACKTRACE_FOUND NO) - -macro(__find_libbacktrace) - find_path(LIBBACKTRACE_INCLUDE_DIR backtrace.h) - find_library(LIBBACKTRACE_STATIC_LIBRARY libbacktrace.a) - find_package_handle_standard_args(LIBBACKTRACE REQUIRED_VARS - LIBBACKTRACE_STATIC_LIBRARY LIBBACKTRACE_INCLUDE_DIR) -endmacro() - -macro(__find_libbacktrace_from PATH) - find_path(LIBBACKTRACE_INCLUDE_DIR backtrace.h - PATHS ${PATH} - PATH_SUFFIXES include - NO_CMAKE_SYSTEM_PATH - NO_SYSTEM_ENVIRONMENT_PATH - ) - find_library(LIBBACKTRACE_STATIC_LIBRARY libbacktrace.a - PATHS ${PATH} - PATH_SUFFIXES lib - NO_CMAKE_SYSTEM_PATH - NO_SYSTEM_ENVIRONMENT_PATH - ) - find_package_handle_standard_args(LIBBACKTRACE REQUIRED_VARS - LIBBACKTRACE_STATIC_LIBRARY LIBBACKTRACE_INCLUDE_DIR) -endmacro() - -macro(__compile_libbacktrace) - message(STATUS "Building libbacktrace from 3rdparty/libbacktrace") - include(cmake/libs/Libbacktrace.cmake) - add_dependencies(libbacktrace project_libbacktrace) - set(LIBBACKTRACE_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/include) - set(LIBBACKTRACE_STATIC_LIBRARY ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace/lib/libbacktrace.a) - add_dependencies(tvm_runtime_objs libbacktrace) - set(LIBBACKTRACE_FOUND YES) -endmacro() - -if(USE_LIBBACKTRACE STREQUAL "AUTO") - __find_libbacktrace() - if(NOT LIBBACKTRACE_FOUND AND (CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Darwin")) - __compile_libbacktrace() - endif() -elseif(USE_LIBBACKTRACE STREQUAL "COMPILE") - __compile_libbacktrace() -elseif("${USE_LIBBACKTRACE}" MATCHES ${IS_TRUE_PATTERN}) - __find_libbacktrace() - if(NOT LIBBACKTRACE_FOUND) - message(SEND_ERROR "libbacktrace not found. (Set USE_LIBBACKTRACE to COMPILE if you want to build with the submodule at 3rdparty/libbacktrace.)") - endif() -elseif("${USE_LIBBACKTRACE}" MATCHES ${IS_FALSE_PATTERN}) -else() - # Treat USE_LIBBACKTRACE as path to libbacktrace - message(STATUS "Using libbacktrace from ${USE_LIBBACKTRACE}") - __find_libbacktrace_from(${USE_LIBBACKTRACE}) - if(NOT LIBBACKTRACE_FOUND) - message(SEND_ERROR "libbacktrace not found from ${USE_LIBBACKTRACE}.") - endif() -endif() - -set_property(TARGET libbacktrace - PROPERTY IMPORTED_LOCATION ${LIBBACKTRACE_STATIC_LIBRARY}) - -function(configure_backtrace TARGET) - if(LIBBACKTRACE_FOUND) - get_target_property(target_type ${TARGET} TYPE) - if(target_type MATCHES "EXECUTABLE|(STATIC|SHARED|MODULE)_LIBRARY") - target_link_libraries(${TARGET} PRIVATE libbacktrace) - endif() - target_include_directories(${TARGET} PRIVATE ${LIBBACKTRACE_INCLUDE_DIR}) - target_compile_definitions(${TARGET} PRIVATE TVM_USE_LIBBACKTRACE=1) - else() - target_compile_definitions(${TARGET} PRIVATE TVM_USE_LIBBACKTRACE=0) - endif() - - if(BACKTRACE_ON_SEGFAULT) - target_compile_definitions(${TARGET} PRIVATE TVM_BACKTRACE_ON_SEGFAULT) - endif() -endfunction() - -configure_backtrace(tvm) -configure_backtrace(tvm_runtime) -configure_backtrace(tvm_objs) -configure_backtrace(tvm_runtime_objs) diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index b9097a02e93f..d11777e8514a 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -33,9 +33,12 @@ if(USE_CUDA AND USE_CUTLASS) ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include ) + target_link_libraries(fpA_intB_gemm_tvm PRIVATE tvm_ffi_header) + set(CUTLASS_FPA_INTB_RUNTIME_SRCS "") list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc) add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS}) + target_link_libraries(fpA_intB_cutlass_objs PRIVATE tvm_ffi_header) target_compile_definitions(fpA_intB_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=) target_include_directories(fpA_intB_cutlass_objs PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index 182a2c66934e..2a243b06c85d 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -199,7 +199,7 @@ macro(find_llvm use_llvm) message(STATUS "LLVM links against zlib") find_package(ZLIB REQUIRED) list(APPEND LLVM_LIBS "ZLIB::ZLIB") - elseif("${__flag}" STREQUAL "-lzstd" OR ("${__flag}" STREQUAL "zstd.dll.lib")) + elseif("${__flag}" STREQUAL "-lzstd") list(APPEND CMAKE_MODULE_PATH "${__llvm_cmakedir}") find_package(zstd REQUIRED) if (TARGET "zstd::libzstd_static") @@ -212,6 +212,9 @@ macro(find_llvm use_llvm) elseif("${__flag}" STREQUAL "-lxml2") message(STATUS "LLVM links against xml2") list(APPEND LLVM_LIBS "-lxml2") + elseif("${__flag}" STREQUAL "zstd.dll.lib") + message(STATUS "LLVM linker flag under LLVM libdir: ${__llvm_libdir}/zstd.lib") + list(APPEND LLVM_LIBS "${__llvm_libdir}/zstd.lib") elseif((__flag MATCHES ".lib$") AND (EXISTS "${__llvm_libdir}/${__flag}")) # If the library file ends in .lib try to also search the llvm_libdir message(STATUS "LLVM linker flag under LLVM libdir: ${__llvm_libdir}/${__flag}") diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index ec8d52226edd..e39468f0bf78 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -166,7 +166,7 @@ then be registered with the following steps. enum value to a string representation. This string representation should match the name given to ``TVM_REGISTER_GLOBAL``. -#. Add entries to the ``MASK2STR`` and ``STR2MASK`` dictionaries of +#. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of :py:class:`tvm.runtime.Device` for the new enum value. diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index 1df8e2e105a7..85e9f45a5fba 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -128,7 +128,7 @@ Python APIs to create a compilation pipeline using pass context. tvm::Array required_pass; tvm::Array disabled_pass; mutable Optional diag_ctx; - Map config; + Map config; Array instruments; }; @@ -209,7 +209,7 @@ passes class ModulePassNode : PassNode { PassInfo pass_info; - runtime::TypedPackedFunc pass_func; + std::function pass_func; Module operator()(const Module& mod, const PassContext& pass_ctx) const final; // Other members/methods are omitted }; @@ -240,7 +240,7 @@ the global information. class FunctionPassNode : PassNode { PassInfo pass_info; - runtime::TypedPackedFunc pass_func; + std::function pass_func; Module operator()(const Module& mod, const PassContext& pass_ctx) const final; bool SkipFunction(const Function& func) const; // Other members/methods are omitted... @@ -306,9 +306,9 @@ pass is registered with an API endpoint as we will show later. Pass GetPass(const std::string& pass_name) { using tvm::runtime::Registry; std::string fpass_name = "relax.transform." + pass_name; - const auto* f = Registry::Get(fpass_name); - ICHECK(f != nullptr) << "Cannot find " << fpass_name - << "to create the pass " << pass_name; + const std::optional f = tvm::ffi::Function::GetGlobal(fpass_name); + ICHECK(f.has_value()) << "Cannot find " << fpass_name + << "to create the pass " << pass_name; return (*f)(); } @@ -319,19 +319,19 @@ favorably use Python APIs to create a specific pass object. .. code:: c++ Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, + std::function pass_func, int opt_level, String name, Array required); Pass CreatePrimFuncPass( - const runtime::TypedPackedFunc& pass_func, + std::function pass_func, int opt_level, String name, Array required); Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, + std::function pass_func, int opt_level, String name, Array required); @@ -371,7 +371,7 @@ Python when needed. namespace transform { Pass FoldConstant() { - runtime::TypedPackedFunc pass_func = + auto pass_func = [=](Function f, IRModule m, PassContext pc) { return ConstantFolder::Fold(f, m); }; return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index dfda00c1d6c4..90e02a4f0ce7 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -56,8 +56,8 @@ The following code block provides an example in C++ void MyAdd(TVMArgs args, TVMRetValue* rv) { // automatically convert arguments to desired type. - int a = args[0]; - int b = args[1]; + int a = args[0].cast(); + int b = args[1].cast(); // automatically assign value return to rv *rv = a + b; } @@ -81,7 +81,7 @@ The following example registers PackedFunc in C++ and calls from python. // register a global packed function in c++ TVM_REGISTER_GLOBAL("myadd") - .set_body(MyAdd); + .set_body_packed(MyAdd); .. code:: python @@ -111,7 +111,7 @@ we can pass functions from python (as PackedFunc) to C++. .. code:: c TVM_REGISTER_GLOBAL("callhello") - .set_body([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](TVMArgs args, TVMRetValue* rv) { PackedFunc f = args[0]; f("hello world"); }); diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py index 9dae16294426..702b53011b48 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py @@ -48,7 +48,7 @@ def main( B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32"), ): - T.func_attr({"tir.noalias": T.bool(True)}) + T.func_attr({"tir.noalias": True}) Y = T.alloc_buffer((128, 128)) for i, j, k in T.grid(128, 128, 128): with T.block("Y"): diff --git a/ffi/3rdparty/dlpack b/ffi/3rdparty/dlpack new file mode 160000 index 000000000000..3ea601bb4130 --- /dev/null +++ b/ffi/3rdparty/dlpack @@ -0,0 +1 @@ +Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt new file mode 100644 index 000000000000..abdc4fed052b --- /dev/null +++ b/ffi/CMakeLists.txt @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.14) + +project( + tvm_ffi + VERSION 1.0 + DESCRIPTION "TVM's FFI system" + LANGUAGES CXX C +) + +option(TVM_FFI_BUILD_TESTS "Adding test targets." OFF) + +########## NOTE: all options below are related to dynamic registry ##### +option(TVM_FFI_BUILD_REGISTRY + "Support for objects with non-static type indices. When turned on, \ + targets linked against `tvm_ffi` will allow objects that comes with non-pre-defined type indices, \ + as well as getting full stacktrace during debugging. \ + so that the object hierarchy could expand without limitation. \ + This will require the downstream targets to link against target `tvm_ffi` to be effective." + OFF +) +option(TVM_FFI_USE_LIBBACKTRACE "Enable libbacktrace" ON) +option(TVM_FFI_BACKTRACE_ON_SEGFAULT "Set signal handler to print traceback on segfault" ON) + +include(cmake/Utils/CxxWarning.cmake) +include(cmake/Utils/Sanitizer.cmake) +include(cmake/Utils/Library.cmake) +if (TVM_FFI_USE_LIBBACKTRACE) + include(cmake/Utils/AddLibbacktrace.cmake) +endif() + +########## Target: `dlpack_header` ########## + +add_library(dlpack_header INTERFACE) +target_include_directories(dlpack_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include") + +########## Target: `tvm_ffi_header` ########## + +add_library(tvm_ffi_header INTERFACE) +target_include_directories(tvm_ffi_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include") +target_link_libraries(tvm_ffi_header INTERFACE dlpack_header) + +########## Target: `tvm_ffi` ########## +add_library(tvm_ffi_objs OBJECT + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback_win.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/error.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/function.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/ndarray.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" +) +set_target_properties( + tvm_ffi_objs PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + CXX_EXTENSIONS OFF + CXX_STANDARD_REQUIRED ON + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + PREFIX "lib" +) +add_cxx_warning(tvm_ffi_objs) +target_link_libraries(tvm_ffi_objs PRIVATE dlpack_header) +target_include_directories(tvm_ffi_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") + +if (TVM_FFI_USE_LIBBACKTRACE) + message(STATUS "Setting C++ macro TVM_FFI_USE_LIBBACKTRACE - 1") + target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_USE_LIBBACKTRACE=1) +else() + message(STATUS "Setting C++ macro TVM_FFI_USE_LIBBACKTRACE - 0") + target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_USE_LIBBACKTRACE=0) +endif() + +if (TVM_FFI_BACKTRACE_ON_SEGFAULT) + message(STATUS "Setting C++ macro TVM_FFI_BACKTRACE_ON_SEGFAULT - 1") + target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_BACKTRACE_ON_SEGFAULT=1) +else() + message(STATUS "Setting C++ macro TVM_FFI_BACKTRACE_ON_SEGFAULT - 0") + target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_BACKTRACE_ON_SEGFAULT=0) +endif() + +add_target_from_obj(tvm_ffi tvm_ffi_objs) + +if (TARGET libbacktrace) + target_link_libraries(tvm_ffi_objs PRIVATE libbacktrace) + target_link_libraries(tvm_ffi_shared PRIVATE libbacktrace) + target_link_libraries(tvm_ffi_static PRIVATE libbacktrace) +endif () + +if (MSVC) + target_link_libraries(tvm_ffi_objs PRIVATE DbgHelp.lib) + target_link_libraries(tvm_ffi_shared PRIVATE DbgHelp.lib) + target_link_libraries(tvm_ffi_static PRIVATE DbgHelp.lib) +endif () + +target_link_libraries(tvm_ffi_objs PUBLIC tvm_ffi_header) +target_link_libraries(tvm_ffi_shared PUBLIC tvm_ffi_header) +target_link_libraries(tvm_ffi_static PUBLIC tvm_ffi_header) + +install(TARGETS tvm_ffi_static DESTINATION lib${LIB_SUFFIX}) +install(TARGETS tvm_ffi_shared DESTINATION lib${LIB_SUFFIX}) + +add_msvc_flags(tvm_ffi_objs) + +########## Adding tests ########## + +if (${PROJECT_NAME} STREQUAL ${CMAKE_PROJECT_NAME}) + if (TVM_FFI_BUILD_TESTS) + enable_testing() + message(STATUS "Enable Testing") + include(cmake/Utils/AddGoogleTest.cmake) + add_subdirectory(tests/cpp/) + endif() +endif () diff --git a/ffi/cmake/Utils/AddGoogleTest.cmake b/ffi/cmake/Utils/AddGoogleTest.cmake new file mode 100644 index 000000000000..10e59386128b --- /dev/null +++ b/ffi/cmake/Utils/AddGoogleTest.cmake @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include(FetchContent) +set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE) +set(BUILD_GMOCK ON CACHE BOOL "" FORCE) +set(BUILD_GTEST ON CACHE BOOL "" FORCE) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +FetchContent_GetProperties(googletest) +if (NOT googletest_POPULATED) + FetchContent_Populate(googletest) + message(STATUS "Found googletest_SOURCE_DIR - ${googletest_SOURCE_DIR}") + message(STATUS "Found googletest_BINARY_DIR - ${googletest_BINARY_DIR}") + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR}) + include(GoogleTest) + set_target_properties(gtest PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + set_target_properties(gtest_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + set_target_properties(gmock PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + set_target_properties(gmock_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + mark_as_advanced( + BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS + gmock_build_tests gtest_build_samples gtest_build_tests + gtest_disable_pthreads gtest_force_shared_crt gtest_hide_internal_symbols + ) +endif() + +macro(add_googletest target_name) + add_test( + NAME ${target_name} + COMMAND ${target_name} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) + target_link_libraries(${target_name} PRIVATE gtest_main) + gtest_discover_tests(${target_name} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + DISCOVERY_MODE PRE_TEST + PROPERTIES + VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + ) + set_target_properties(${target_name} PROPERTIES FOLDER tests) +endmacro() diff --git a/ffi/cmake/Utils/AddLibbacktrace.cmake b/ffi/cmake/Utils/AddLibbacktrace.cmake new file mode 100644 index 000000000000..844a8816a6d8 --- /dev/null +++ b/ffi/cmake/Utils/AddLibbacktrace.cmake @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include(ExternalProject) + +function(_libbacktrace_compile) + set(_libbacktrace_source ${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/libbacktrace) + set(_libbacktrace_prefix ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace) + if(CMAKE_SYSTEM_NAME MATCHES "Darwin" AND (CMAKE_C_COMPILER MATCHES "^/Library" OR CMAKE_C_COMPILER MATCHES "^/Applications")) + set(_cmake_c_compiler "/usr/bin/cc") + else() + set(_cmake_c_compiler "${CMAKE_C_COMPILER}") + endif() + + message(STATUS CMAKC_C_COMPILER="${CMAKE_C_COMPILER}") + + file(MAKE_DIRECTORY ${_libbacktrace_prefix}/include) + file(MAKE_DIRECTORY ${_libbacktrace_prefix}/lib) + + ExternalProject_Add(project_libbacktrace + PREFIX libbacktrace + SOURCE_DIR ${_libbacktrace_source} + BINARY_DIR ${_libbacktrace_prefix} + CONFIGURE_COMMAND + "${_libbacktrace_source}/configure" + "--prefix=${_libbacktrace_prefix}" + --with-pic + "CC=${_cmake_c_compiler}" + "CPP=${_cmake_c_compiler} -E" + "CFLAGS=${CMAKE_C_FLAGS}" + "LDFLAGS=${CMAKE_EXE_LINKER_FLAGS}" + "NM=${CMAKE_NM}" + "STRIP=${CMAKE_STRIP}" + "--host=${MACHINE_NAME}" + INSTALL_DIR ${_libbacktrace_prefix} + BUILD_COMMAND make + INSTALL_COMMAND make install + BUILD_BYPRODUCTS "${_libbacktrace_prefix}/lib/libbacktrace.a" + "${_libbacktrace_prefix}/include/backtrace.h" + ) + ExternalProject_Add_Step(project_libbacktrace checkout DEPENDERS configure DEPENDEES download) + set_target_properties(project_libbacktrace PROPERTIES EXCLUDE_FROM_ALL TRUE) + add_library(libbacktrace STATIC IMPORTED) + add_dependencies(libbacktrace project_libbacktrace) + set_target_properties(libbacktrace PROPERTIES + IMPORTED_LOCATION ${_libbacktrace_prefix}/lib/libbacktrace.a + INTERFACE_INCLUDE_DIRECTORIES ${_libbacktrace_prefix}/include + ) +endfunction() + +if(NOT MSVC) + _libbacktrace_compile() +endif() diff --git a/ffi/cmake/Utils/CxxWarning.cmake b/ffi/cmake/Utils/CxxWarning.cmake new file mode 100644 index 000000000000..c272bfdf7bf2 --- /dev/null +++ b/ffi/cmake/Utils/CxxWarning.cmake @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function(add_cxx_warning target_name) + # GNU, Clang, or AppleClang + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") + target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic" "-Wno-unused-parameter") + return() + endif() + # MSVC + if(MSVC) + # target_compile_options(${target_name} PRIVATE "/W4" "/WX") + return() + endif() + message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") +endfunction() diff --git a/ffi/cmake/Utils/Library.cmake b/ffi/cmake/Utils/Library.cmake new file mode 100644 index 000000000000..f391ee8fd460 --- /dev/null +++ b/ffi/cmake/Utils/Library.cmake @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +function(add_dsymutil target_name) + # running dsymutil on macos to generate debugging symbols for backtraces + if(APPLE) + find_program(DSYMUTIL dsymutil) + mark_as_advanced(DSYMUTIL) + add_custom_command(TARGET ${target_name} + POST_BUILD + COMMAND ${DSYMUTIL} ARGS $ + COMMENT "[COMMAND] dsymutil $" + VERBATIM + ) + endif() +endfunction() + +function(add_msvc_flags target_name) + # running if we are under msvc + if(MSVC) + target_compile_definitions(${target_name} PUBLIC -DWIN32_LEAN_AND_MEAN) + target_compile_definitions(${target_name} PUBLIC -D_CRT_SECURE_NO_WARNINGS) + target_compile_definitions(${target_name} PUBLIC -D_SCL_SECURE_NO_WARNINGS) + target_compile_definitions(${target_name} PUBLIC -D_ENABLE_EXTENDED_ALIGNED_STORAGE) + target_compile_definitions(${target_name} PUBLIC -DNOMINMAX) + target_compile_options(${target_name} PRIVATE "/Z7") + endif() +endfunction() + +function(add_target_from_obj target_name obj_target_name) + add_library(${target_name}_static STATIC $) + set_target_properties( + ${target_name}_static PROPERTIES + OUTPUT_NAME "${target_name}_static" + PREFIX "lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) + add_library(${target_name}_shared SHARED $) + set_target_properties( + ${target_name}_shared PROPERTIES + OUTPUT_NAME "${target_name}" + PREFIX "lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) + add_custom_target(${target_name}) + add_dependencies(${target_name} ${target_name}_static ${target_name}_shared) + if (MSVC) + target_compile_definitions(${obj_target_name} PRIVATE TVM_FFI_EXPORTS) + endif() + add_dsymutil(${target_name}_shared) + add_msvc_flags(${target_name}_shared) +endfunction() diff --git a/ffi/cmake/Utils/Sanitizer.cmake b/ffi/cmake/Utils/Sanitizer.cmake new file mode 100644 index 000000000000..a20eead0c869 --- /dev/null +++ b/ffi/cmake/Utils/Sanitizer.cmake @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function(add_sanitizer_address target_name) + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") + include(CheckCXXCompilerFlag) + set (_saved_CRF ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") + check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN) + set (CMAKE_REQUIRED_FLAGS ${_saved_CRF}) + get_target_property(_saved_type ${target_name} TYPE) + if (${_saved_type} STREQUAL "INTERFACE_LIBRARY") + set(_saved_type INTERFACE) + else() + set(_saved_type PRIVATE) + endif() + target_link_options(${target_name} ${_saved_type} "-fsanitize=address") + target_compile_options(${target_name} ${_saved_type} "-fsanitize=address" "-fno-omit-frame-pointer" "-g") + return() + endif() +endfunction() diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h new file mode 100644 index 000000000000..8274904037bc --- /dev/null +++ b/ffi/include/tvm/ffi/any.h @@ -0,0 +1,492 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/any.h + * \brief Any value support. + */ +#ifndef TVM_FFI_ANY_H_ +#define TVM_FFI_ANY_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { + +class Any; + +namespace details { +// Helper to perform +// unsafe operations related to object +struct AnyUnsafe; +} // namespace details + +/*! + * \brief AnyView allows us to take un-managed reference view of any value. + */ +class AnyView { + protected: + /*! \brief The underlying backing data of the any object */ + TVMFFIAny data_; + // Any can see AnyView + friend class Any; + + public: + // NOTE: the following two functions uses styl style + // since they are common functions appearing in FFI. + /*! + * \brief Reset any view to None + */ + void reset() { + data_.type_index = TypeIndex::kTVMFFINone; + // invariance: always set the union padding part to 0 + data_.v_int64 = 0; + } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + TVM_FFI_INLINE void swap(AnyView& other) noexcept { std::swap(data_, other.data_); } + /*! \return the internal type index */ + TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } + // default constructors + AnyView() { + data_.type_index = TypeIndex::kTVMFFINone; + data_.v_int64 = 0; + } + ~AnyView() = default; + // constructors from any view + AnyView(const AnyView&) = default; + AnyView& operator=(const AnyView&) = default; + AnyView(AnyView&& other) : data_(other.data_) { + other.data_.type_index = TypeIndex::kTVMFFINone; + other.data_.v_int64 = 0; + } + TVM_FFI_INLINE AnyView& operator=(AnyView&& other) { + // copy-and-swap idiom + AnyView(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + // constructor from general types + template ::convert_enabled>> + AnyView(const T& other) { // NOLINT(*) + TypeTraits::CopyToAnyView(other, &data_); + } + template ::convert_enabled>> + TVM_FFI_INLINE AnyView& operator=(const T& other) { // NOLINT(*) + // copy-and-swap idiom + AnyView(other).swap(*this); // NOLINT(*) + return *this; + } + + template ::convert_enabled>> + TVM_FFI_INLINE std::optional as() const { + return TypeTraits::TryConvertFromAnyView(&data_); + } + + template ::convert_enabled>> + TVM_FFI_INLINE T cast() const { + std::optional opt = TypeTraits::TryConvertFromAnyView(&data_); + if (!opt.has_value()) { + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits::TypeStr() << "`"; + } + return *std::move(opt); + } + + /* + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T* as() const { + return this->as().value_or(nullptr); + } + // comparison with nullptr + TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { + return data_.type_index == TypeIndex::kTVMFFINone; + } + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { + return data_.type_index != TypeIndex::kTVMFFINone; + } + /*! + * \brief Get the type key of the Any + * \return The type key of the Any + */ + TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } + // The following functions are only used for testing purposes + /*! + * \return The underlying supporting data of any view + * \note This function is used only for testing purposes. + */ + TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; } + /*! + * \return Create an AnyView from TVMFFIAny + * \param data the underlying ffi data. + */ + static TVM_FFI_INLINE AnyView CopyFromTVMFFIAny(TVMFFIAny data) { + AnyView view; + view.data_ = data; + return view; + } +}; + +namespace details { +/*! + * \brief Helper function to inplace convert any view to any. + * \param data The pointer that represents the format as any view. + * \param extra_any_bytes Indicate that the data may contain extra bytes following + * the TVMFFIAny data structure. This is reserved for future possible optimizations + * of small-string and extended any object. + */ +TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, + [[maybe_unused]] size_t extra_any_bytes = 0) { + if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::IncRefObjectHandle(data->v_obj); + } else if (data->type_index >= TypeIndex::kTVMFFIRawStr) { + if (data->type_index == TypeIndex::kTVMFFIRawStr) { + // convert raw string to owned string object + String temp(data->v_c_str); + data->type_index = TypeIndex::kTVMFFIStr; + data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp)); + } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) { + // convert byte array to owned bytes object + Bytes temp(*static_cast(data->v_ptr)); + data->type_index = TypeIndex::kTVMFFIBytes; + data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp)); + } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + // convert rvalue ref to owned object + Object** obj_addr = static_cast(data->v_ptr); + TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved"; + ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned(obj_addr[0])); + // set the rvalue ref to nullptr to avoid double move + obj_addr[0] = nullptr; + data->type_index = temp->type_index(); + data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp)); + } + } +} +} // namespace details + +/*! + * \brief Managed Any that takes strong reference to a value. + * + * \note Develooper invariance: the TVMFFIAny data_ + * in the Any can be safely used in AnyView. + */ +class Any { + protected: + /*! \brief The underlying backing data of the any object */ + TVMFFIAny data_; + + public: + /*! + * \brief Reset any to None + */ + TVM_FFI_INLINE void reset() { + if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); + } + data_.type_index = TVMFFITypeIndex::kTVMFFINone; + data_.v_int64 = 0; + } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + TVM_FFI_INLINE void swap(Any& other) noexcept { std::swap(data_, other.data_); } + /*! \return the internal type index */ + TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } + // default constructors + Any() { + data_.type_index = TypeIndex::kTVMFFINone; + data_.v_int64 = 0; + } + ~Any() { this->reset(); } + // constructors from Any + Any(const Any& other) : data_(other.data_) { + if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); + } + } + Any(Any&& other) : data_(other.data_) { + other.data_.type_index = TypeIndex::kTVMFFINone; + other.data_.v_int64 = 0; + } + TVM_FFI_INLINE Any& operator=(const Any& other) { + // copy-and-swap idiom + Any(other).swap(*this); // NOLINT(*) + return *this; + } + TVM_FFI_INLINE Any& operator=(Any&& other) { + // copy-and-swap idiom + Any(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + // convert from/to AnyView + Any(const AnyView& other) : data_(other.data_) { // NOLINT(*) + details::InplaceConvertAnyViewToAny(&data_); + } + TVM_FFI_INLINE Any& operator=(const AnyView& other) { + // copy-and-swap idiom + Any(other).swap(*this); // NOLINT(*) + return *this; + } + /*! \brief Any can be converted to AnyView in zero cost. */ + operator AnyView() const { return AnyView::CopyFromTVMFFIAny(data_); } + // constructor from general types + template ::convert_enabled>> + Any(T other) { // NOLINT(*) + TypeTraits::MoveToAny(std::move(other), &data_); + } + template ::convert_enabled>> + TVM_FFI_INLINE Any& operator=(T other) { // NOLINT(*) + // copy-and-swap idiom + Any(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + template ::convert_enabled || std::is_same_v>> + TVM_FFI_INLINE std::optional as() const { + if constexpr (std::is_same_v) { + return *this; + } else { + return TypeTraits::TryConvertFromAnyView(&data_); + } + } + + /* + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T* as() const { + return this->as().value_or(nullptr); + } + + template ::convert_enabled>> + TVM_FFI_INLINE T cast() const& { + std::optional opt = TypeTraits::TryConvertFromAnyView(&data_); + if (!opt.has_value()) { + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits::TypeStr() << "`"; + } + return *std::move(opt); + } + + template ::storage_enabled>> + TVM_FFI_INLINE T cast() && { + if (TypeTraits::CheckAnyStorage(&data_)) { + return TypeTraits::MoveFromAnyStorageAfterCheck(&data_); + } + // slow path, try to do fallback convert + std::optional opt = TypeTraits::TryConvertFromAnyView(&data_); + if (!opt.has_value()) { + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits::TypeStr() << "`"; + } + return *std::move(opt); + } + + /* + * \brief Check if the two Any are same type and value in shallow comparison. + * \param other The other Any + * \return True if the two Any are same type and value, false otherwise. + */ + TVM_FFI_INLINE bool same_as(const Any& other) const noexcept { + return data_.type_index == other.data_.type_index && data_.v_int64 == other.data_.v_int64; + } + + /* + * \brief Check if any and ObjectRef are same type and value in shallow comparison. + * \param other The other ObjectRef + * \return True if the two Any are same type and value, false otherwise. + */ + TVM_FFI_INLINE bool same_as(const ObjectRef& other) const noexcept { + if (other.get() != nullptr) { + return (data_.type_index == other->type_index() && + reinterpret_cast(data_.v_obj) == other.get()); + } else { + return data_.type_index == TypeIndex::kTVMFFINone; + } + } + + TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { + return data_.type_index == TypeIndex::kTVMFFINone; + } + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { + return data_.type_index != TypeIndex::kTVMFFINone; + } + + /*! + * \brief Get the type key of the Any + * \return The type key of the Any + */ + TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } + + friend struct details::AnyUnsafe; + friend struct AnyHash; + friend struct AnyEqual; +}; + +// layout assert to ensure we can freely cast between the two types +static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); +static_assert(sizeof(Any) == sizeof(TVMFFIAny)); + +namespace details { + +template +struct Type2Str { + static std::string v() { return TypeTraitsNoCR::TypeStr(); } +}; + +template <> +struct Type2Str { + static std::string v() { return "Any"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "Any"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "AnyView"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "AnyView"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "void"; } +}; + +// Extra unsafe method to help any manipulation +struct AnyUnsafe : public ObjectUnsafe { + // FFI related operations + static TVM_FFI_INLINE TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) { + TVMFFIAny result = any.data_; + any.data_.type_index = TypeIndex::kTVMFFINone; + any.data_.v_int64 = 0; + return result; + } + + static TVM_FFI_INLINE Any MoveTVMFFIAnyToAny(TVMFFIAny&& data) { + Any any; + any.data_ = data; + data.type_index = TypeIndex::kTVMFFINone; + data.v_int64 = 0; + return any; + } + + template + static TVM_FFI_INLINE bool CheckAnyStorage(const Any& ref) { + return TypeTraits::CheckAnyStorage(&(ref.data_)); + } + + template + static TVM_FFI_INLINE T CopyFromAnyStorageAfterCheck(const Any& ref) { + if constexpr (!std::is_same_v) { + return TypeTraits::CopyFromAnyStorageAfterCheck(&(ref.data_)); + } else { + return ref; + } + } + + static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) { + return reinterpret_cast(ref.data_.v_obj); + } + + static TVM_FFI_INLINE const TVMFFIAny* TVMFFIAnyPtrFromAny(const Any& ref) { + return &(ref.data_); + } + + template + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const Any& ref) { + return TypeTraits::GetMismatchTypeInfo(&(ref.data_)); + } +}; +} // namespace details + +/*! \brief String-aware Any equal functor */ +struct AnyHash { + /*! + * \brief Calculate the hash code of an Any + * \param a The given Any + * \return Hash code of a, string hash for strings and pointer address otherwise. + */ + uint64_t operator()(const Any& src) const { + uint64_t val_hash = [&]() -> uint64_t { + if (src.data_.type_index == TypeIndex::kTVMFFIStr || + src.data_.type_index == TypeIndex::kTVMFFIBytes) { + const BytesObjBase* src_str = + details::AnyUnsafe::CopyFromAnyStorageAfterCheck(src); + return details::StableHashBytes(src_str->data, src_str->size); + } else { + return src.data_.v_uint64; + } + }(); + return details::StableHashCombine(src.data_.type_index, val_hash); + } +}; + +/*! \brief String-aware Any hash functor */ +struct AnyEqual { + /*! + * \brief Check if the two Any are equal + * \param lhs left operand. + * \param rhs right operand + * \return String equality if both are strings, pointer address equality otherwise. + */ + bool operator()(const Any& lhs, const Any& rhs) const { + if (lhs.data_.type_index != rhs.data_.type_index) return false; + // byte equivalence + if (lhs.data_.v_int64 == rhs.data_.v_int64) return true; + // specialy handle string hash + if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || + lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { + const BytesObjBase* lhs_str = + details::AnyUnsafe::CopyFromAnyStorageAfterCheck(lhs); + const BytesObjBase* rhs_str = + details::AnyUnsafe::CopyFromAnyStorageAfterCheck(rhs); + return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; + } + return false; + } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h new file mode 100644 index 000000000000..18cc3ecb726f --- /dev/null +++ b/ffi/include/tvm/ffi/base_details.h @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/base_details.h + * \brief Internal detail utils that can be used by files in tvm/ffi. + * \note details header are for internal use only + * and not to be directly used by user. + */ +#ifndef TVM_FFI_BASE_DETAILS_H_ +#define TVM_FFI_BASE_DETAILS_H_ + +#include +#include + +#include +#include + +#if defined(_MSC_VER) +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif + +#ifndef NOMINMAX +#define NOMINMAX +#endif + +#include + +#ifdef ERROR +#undef ERROR +#endif + +#endif + +#if defined(_MSC_VER) +#define TVM_FFI_INLINE __forceinline +#else +#define TVM_FFI_INLINE inline __attribute__((always_inline)) +#endif + +/*! + * \brief Macro helper to force a function not to be inlined. + * It is only used in places that we know not inlining is good, + * e.g. some logging functions. + */ +#if defined(_MSC_VER) +#define TVM_FFI_NO_INLINE __declspec(noinline) +#else +#define TVM_FFI_NO_INLINE __attribute__((noinline)) +#endif + +#if defined(_MSC_VER) +#define TVM_FFI_UNREACHABLE() __assume(false) +#else +#define TVM_FFI_UNREACHABLE() __builtin_unreachable() +#endif + +/*! \brief helper macro to suppress unused warning */ +#if defined(__GNUC__) +#define TVM_FFI_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_FFI_ATTRIBUTE_UNUSED +#endif + +#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y +#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y) + +#if defined(__GNUC__) || defined(__clang__) +#define TVM_FFI_FUNC_SIG __PRETTY_FUNCTION__ +#elif defined(_MSC_VER) +#define TVM_FFI_FUNC_SIG __FUNCSIG__ +#else +#define TVM_FFI_FUNC_SIG __func__ +#endif + +/* + * \brief Define the default copy/move constructor and assign operator + * \param TypeName The class typename. + */ +#define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; + +/** + * \brief marks the begining of a C call that logs exception + */ +#define TVM_FFI_LOG_EXCEPTION_CALL_BEGIN() \ + try { \ + (void)0 + +/*! + * \brief Marks the end of a C call that logs exception + */ +#define TVM_FFI_LOG_EXCEPTION_CALL_END(Name) \ + } \ + catch (const std::exception& err) { \ + std::cerr << "Exception caught during " << #Name << ":\n" << err.what() << std::endl; \ + exit(-1); \ + } + +/*! + * \brief Clear the padding parts so we can safely use v_int64 for hash + * and equality check even when the value stored is a pointer. + * + * This macro is used to clear the padding parts for hash and equality check + * in 32bit platform. + */ +#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ + if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \ + result->v_int64 = 0; \ + } + +namespace tvm { +namespace ffi { +namespace details { + +/********** Atomic Operations *********/ + +TVM_FFI_INLINE int32_t AtomicIncrementRelaxed(int32_t* ptr) { +#ifdef _MSC_VER + return _InterlockedIncrement(reinterpret_cast(ptr)) - 1; // NOLINT(*) +#else + return __atomic_fetch_add(ptr, 1, __ATOMIC_RELAXED); +#endif +} + +TVM_FFI_INLINE int32_t AtomicDecrementRelAcq(int32_t* ptr) { +#ifdef _MSC_VER + return _InterlockedDecrement(reinterpret_cast(ptr)) + 1; // NOLINT(*) +#else + return __atomic_fetch_sub(ptr, 1, __ATOMIC_ACQ_REL); +#endif +} + +TVM_FFI_INLINE int32_t AtomicLoadRelaxed(const int32_t* ptr) { + int32_t* raw_ptr = const_cast(ptr); +#ifdef _MSC_VER + // simply load the variable ptr out + return (reinterpret_cast(raw_ptr))[0]; // NOLINT(*) +#else + return __atomic_load_n(raw_ptr, __ATOMIC_RELAXED); +#endif +} + +// for each iterator +template +struct for_each_dispatcher { + template + static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) + f(I, std::forward(value)); + for_each_dispatcher::run(f, std::forward(args)...); + } +}; + +template +struct for_each_dispatcher { + static void run(const F&) {} // NOLINT(*) +}; + +template +void for_each(const F& f, Args&&... args) { // NOLINT(*) + for_each_dispatcher::run(f, std::forward(args)...); +} + +/*! + * \brief hash an object and combines uint64_t key with previous keys + * + * This hash function is stable across platforms. + * + * \param key The left operand. + * \param value The right operand. + * \return the combined result. + */ +template ::value, bool> = true> +TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T& value) { + // XXX: do not use std::hash in this function. This hash must be stable + // across different platforms and std::hash is implementation dependent. + return key ^ (uint64_t(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); +} + +/*! + * \brief Hash the binary bytes + * \param data The data pointer + * \param size The size of the bytes. + * \return the hash value. + */ +TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) { + const constexpr uint64_t kMultiplier = 1099511628211ULL; + const constexpr uint64_t kMod = 2147483647ULL; + union Union { + uint8_t a[8]; + uint64_t b; + } u; + static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); + const char* it = data; + const char* end = it + size; + uint64_t result = 0; + for (; it + 8 <= end; it += 8) { + if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { + u.a[0] = it[0]; + u.a[1] = it[1]; + u.a[2] = it[2]; + u.a[3] = it[3]; + u.a[4] = it[4]; + u.a[5] = it[5]; + u.a[6] = it[6]; + u.a[7] = it[7]; + } else { + u.a[0] = it[7]; + u.a[1] = it[6]; + u.a[2] = it[5]; + u.a[3] = it[4]; + u.a[4] = it[3]; + u.a[5] = it[2]; + u.a[6] = it[1]; + u.a[7] = it[0]; + } + result = (result * kMultiplier + u.b) % kMod; + } + if (it < end) { + u.b = 0; + uint8_t* a = u.a; + if (it + 4 <= end) { + a[0] = it[0]; + a[1] = it[1]; + a[2] = it[2]; + a[3] = it[3]; + it += 4; + a += 4; + } + if (it + 2 <= end) { + a[0] = it[0]; + a[1] = it[1]; + it += 2; + a += 2; + } + if (it + 1 <= end) { + a[0] = it[0]; + it += 1; + a += 1; + } + if constexpr (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + std::swap(u.a[0], u.a[7]); + std::swap(u.a[1], u.a[6]); + std::swap(u.a[2], u.a[5]); + std::swap(u.a[3], u.a[4]); + } + result = (result * kMultiplier + u.b) % kMod; + } + return result; +} + +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h new file mode 100644 index 000000000000..1d495d9c5e96 --- /dev/null +++ b/ffi/include/tvm/ffi/c_api.h @@ -0,0 +1,718 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/c_api.h + * \brief This file defines the C convention of the FFI convention + */ +#ifndef TVM_FFI_C_API_H_ +#define TVM_FFI_C_API_H_ + +#include +#include + +#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) +#include +#define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE +#endif +#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) +#ifdef TVM_FFI_EXPORTS +#define TVM_FFI_DLL __declspec(dllexport) +#else +#define TVM_FFI_DLL __declspec(dllimport) +#endif +#endif +#ifndef TVM_FFI_DLL +#define TVM_FFI_DLL __attribute__((visibility("default"))) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __cplusplus +enum TVMFFITypeIndex : int32_t { +#else +typedef enum { +#endif + // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) + // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, + // which is not owned by TVMFFIAny. It is required that the following + // invariant holds: + // - `Any::type_index` is never `kTVMFFIRawStr` + // - `AnyView::type_index` can be `kTVMFFIRawStr` + // + /* + * \brief The root type of all FFI objects. + * + * We include it so TypeIndex captures all possible runtime values. + * `kTVMFFIAny` code will never appear in Any::type_index. + * However, it may appear in field annotations during reflection. + */ + kTVMFFIAny = -1, + /*! \brief None/nullptr value */ + kTVMFFINone = 0, + /*! \brief POD int value */ + kTVMFFIInt = 1, + /*! \brief POD bool value */ + kTVMFFIBool = 2, + /*! \brief POD float value */ + kTVMFFIFloat = 3, + /*! \brief Opaque pointer object */ + kTVMFFIOpaquePtr = 4, + /*! \brief DLDataType */ + kTVMFFIDataType = 5, + /*! \brief DLDevice */ + kTVMFFIDevice = 6, + /*! \brief DLTensor* */ + kTVMFFIDLTensorPtr = 7, + /*! \brief const char**/ + kTVMFFIRawStr = 8, + /*! \brief TVMFFIByteArray* */ + kTVMFFIByteArrayPtr = 9, + /*! \brief R-value reference to ObjectRef */ + kTVMFFIObjectRValueRef = 10, + /*! \brief Start of statically defined objects. */ + kTVMFFIStaticObjectBegin = 64, + /*! + * \brief Object, all objects starts with TVMFFIObject as its header. + * \note We will also add other fields + */ + kTVMFFIObject = 64, + /*! + * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIStr = 65, + /*! + * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIBytes = 66, + /*! \brief Error object. */ + kTVMFFIError = 67, + /*! \brief Function object. */ + kTVMFFIFunction = 68, + /*! \brief Array object. */ + kTVMFFIArray = 69, + /*! \brief Map object. */ + kTVMFFIMap = 70, + /*! + * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } + */ + kTVMFFIShape = 71, + /*! + * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... } + */ + kTVMFFINDArray = 72, + /*! \brief Runtime module object. */ + kTVMFFIModule = 73, + kTVMFFIStaticObjectEnd, + // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) + /*! \brief Start of type indices that are allocated at runtime. */ + kTVMFFIDynObjectBegin = 128 +#ifdef __cplusplus +}; +#else +} TVMFFITypeIndex; +#endif + +/*! \brief Handle to Object from C API's pov */ +typedef void* TVMFFIObjectHandle; + +/*! + * \brief C-based type of all FFI object header that allocates on heap. + * \note TVMFFIObject and TVMFFIAny share the common type_index header + */ +typedef struct TVMFFIObject { + /*! + * \brief type index of the object. + * \note The type index of Object and Any are shared in FFI. + */ + int32_t type_index; + /*! \brief Reference counter of the object. */ + int32_t ref_counter; + union { + /*! \brief Deleter to be invoked when reference counter goes to zero. */ + void (*deleter)(struct TVMFFIObject* self); + /*! + * \brief auxilary field to TVMFFIObject is always 8 bytes aligned. + * \note This helps us to ensure cross platform compatibility. + */ + int64_t __ensure_align; + }; +} TVMFFIObject; + +/*! + * \brief C-based type of all on stack Any value. + * + * Any value can hold on stack values like int, + * as well as reference counted pointers to object. + */ +typedef struct TVMFFIAny { + /*! + * \brief type index of the object. + * \note The type index of Object and Any are shared in FFI. + */ + int32_t type_index; + /*! + * \brief length for on-stack Any object, such as small-string + * \note This field is reserved for future compact. + */ + int32_t small_len; + union { // 8 bytes + int64_t v_int64; // integers + double v_float64; // floating-point numbers + void* v_ptr; // typeless pointers + const char* v_c_str; // raw C-string + TVMFFIObject* v_obj; // ref counted objects + DLDataType v_dtype; // data type + DLDevice v_device; // device + char v_bytes[8]; // small string + char32_t v_char32[2]; // small UCS4 string and Unicode + uint64_t v_uint64; // uint64 repr mainly used for hashing + }; +} TVMFFIAny; + +/*! + * \brief Byte array data structure used by String and Bytes. + * + * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... } + * + * \note This byte array data structure layout differs in 32/64 bit platforms. + * as size_t equals to the size of the pointer, use this convetion to + * be consistent with std::string and also avoid need to calculate padding + * for the size field on 32-bit platforms. + * The FFI binding should be careful when treating this ABI. + */ +typedef struct { + const char* data; + size_t size; +} TVMFFIByteArray; + +/*! + * \brief Shape cell used in shape object following header. + */ +typedef struct { + const int64_t* data; + size_t size; +} TVMFFIShapeCell; + +/*! + * \brief Error cell used in error object following header. + */ +typedef struct { + /*! \brief The kind of the error. */ + TVMFFIByteArray kind; + /*! \brief The message of the error. */ + TVMFFIByteArray message; + /*! + * \brief The traceback of the error. + */ + TVMFFIByteArray traceback; + /*! + * \brief Function handle to update the traceback of the error. + * \param self The self object handle. + * \param traceback The traceback to update. + */ + void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback); +} TVMFFIErrorCell; + +/*! + * \brief Type that defines C-style safe call convention + * + * Safe call explicitly catches exception on function boundary. + * + * \param self The function handle + * \param num_args Number of input arguments + * \param args The input arguments to the call. + * \param result Store output result. + * + * IMPORTANT: caller must initialize result->type_index to be kTVMFFINone, + * or any other value smaller than kTVMFFIStaticObjectBegin. + * + * \return The call returns 0 if call is successful. + * It returns non-zero value if there is an error. + * + * Possible return error of the API functions: + * * 0: success + * * -1: error happens, can be retrieved by TVMFFIErrorMoveFromRaised + * * -2: a frontend error occurred and recorded in the frontend. + * + * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaised + * for C function error propagation. This design choice, while + * introducing a dependency for TLS runtime, simplifies error + * propgation in chains of calls in compiler codegen. + * As we do not need to propagate error through argument but simply + * set them in the runtime environment. + * + * \sa TVMFFIErrorMoveFromRaised + * \sa TVMFFIErrorSetRaised + * \sa TVMFFIErrorSetRaisedByCStr + */ +typedef int (*TVMFFISafeCallType)(void* self, const TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result); + +/*! + * \brief Object cell for function object following header. + */ +typedef struct { + /*! \brief A C API compatible call with exception catching. */ + TVMFFISafeCallType safe_call; +} TVMFFIFunctionCell; + +/*! + * \brief Getter that can take address of a field and set the result. + * \param field The raw address of the field. + * \param result Stores the result. + */ +typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result); + +/*! + * \brief Getter that can take address of a field and set to value. + * \param field The raw address of the field. + * \param value The value to set. + */ +typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value); + +/*! + * \brief Information support for optional object reflection. + */ +typedef struct { + /*! \brief The name of the field. */ + TVMFFIByteArray name; + /*! + * \brief Records the static type kind of the field. + * + * Possible values: + * + * - TVMFFITypeIndex::kTVMFFIObject for general objects + * - The value is nullable when kTVMFFIObject is chosen + * - static object type kinds such as Map, Dict, String + * - POD type index + * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info + * about the field. + * + * \note This information is helpful in designing serializer + * of the field. As it helps to narrow down the type of the + * object. It also helps to provide opportunities to enable + * short-cut access to the field. + */ + int32_t field_static_type_index; + /*! + * \brief Mark whether field is readonly. + */ + int32_t readonly; + /*! + * \brief Byte offset of the field. + */ + int64_t byte_offset; + /*! \brief The getter to access the field. */ + TVMFFIFieldGetter getter; + /*! \brief The setter to access the field. */ + TVMFFIFieldSetter setter; +} TVMFFIFieldInfo; + +/*! + * \brief Method information that can appear in reflection table. + */ +typedef struct { + /*! \brief The name of the field. */ + TVMFFIByteArray name; + /*! + * \brief The method wrapped as Function + * \note The first argument to the method is always the self. + */ + TVMFFIObjectHandle method; +} TVMFFIMethodInfo; + +/*! + * \brief Runtime type information for object type checking. + */ +typedef struct { + /*! + *\brief The runtime type index, + * It can be allocated during runtime if the type is dynamic. + */ + int32_t type_index; + /*! \brief number of parent types in the type hierachy. */ + int32_t type_depth; + /*! \brief the unique type key to identify the type. */ + TVMFFIByteArray type_key; + /*! \brief Cached hash value of the type key, used for consistent structural hashing. */ + uint64_t type_key_hash; + /*! + * \brief type_acenstors[depth] stores the type_index of the acenstors at depth level + * \note To keep things simple, we do not allow multiple inheritance so the + * hieracy stays as a tree + */ + const int32_t* type_acenstors; + /*! \brief number of reflection accessible fields. */ + int32_t num_fields; + /*! \brief number of reflection acccesible methods. */ + int32_t num_methods; + /*! \brief The reflection field information. */ + TVMFFIFieldInfo* fields; + /*! \brief The reflection method. */ + TVMFFIMethodInfo* methods; +} TVMFFITypeInfo; + +//------------------------------------------------------------ +// Section: User APIs to interact with the FFI +//------------------------------------------------------------ +/*! + * \brief Free an object handle by decreasing reference + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); + +//----------------------------------------------------------------------- +// Section: Function calling APIs and support API for func implementation +//----------------------------------------------------------------------- +/*! + * \brief Create a FFIFunc by passing in callbacks from C callback. + * + * The registered function then can be pulled by the backend by the name. + * + * \param self The resource handle of the C callback. + * \param safe_call The C callback implementation + * \param deleter deleter to recycle + * \param out The output of the function. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, + void (*deleter)(void* self), TVMFFIObjectHandle* out); + +/*! + * \brief Convert a AnyView to an owned Any. + * \param any The AnyView to convert. + * \param out The output Any, must be an empty object + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); + +/*! + * \brief Call a FFIFunc by passing in arguments. + * + * \param func The resource handle of the C callback. + * \param args The input arguments to the call. + * \param num_args The number of input arguments. + * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result); + +/*! + * \brief Register the function to runtime's global table. + * + * The registered function then can be pulled by the backend by the name. + * + * \param name The name of the function. + * \param f The function to be registered. + * \param override Whether allow override already registered function. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, + int override); + +/*! + * \brief Get a global function. + * + * \param name The name of the function. + * \param out the result function pointer, NULL if it does not exist. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); + +/*! + * \brief Move the last error from the environment to result. + * + * \param result The result error. + * + * \note This function clears the error stored in the TLS. + */ +TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result); + +/*! + * \brief Set raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. + * + * \param error The error object handle + */ +TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); + +/*! + * \brief Set raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. + * + * \param kind The kind of the error. + * \param message The error message. + * \note This is a convenient method for C API side to set error directly from string. + */ +TVM_FFI_DLL void TVMFFIErrorSetRaisedByCStr(const char* kind, const char* message); + +/*! + * \brief Create an initial error object. + * + * \param kind The kind of the error. + * \param message The error message. + * \param traceback The traceback of the error. + * \return The created error object handle. + * \note This function is different from other functions as it is used in error handling loop. + * So we do not follow normal error handling patterns via returning error code. + */ +TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, + const TVMFFIByteArray* message, + const TVMFFIByteArray* traceback); + +/*! + * \brief Check if there are any signals raised in the surrounding env. + * \return 0 when success, nonzero when failure happens + * \note Under python this function redirects to PyErr_CheckSignals + */ +TVM_FFI_DLL int TVMFFIEnvCheckSignals(); + +/*! + * \brief Register a symbol into the from the surrounding env. + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol); + +//------------------------------------------------------------ +// Section: Type reflection support APIs +//------------------------------------------------------------ +/*! + * \brief Register type field information for rutnime reflection. + * \param type_index The type index + * \param info The field info to be registered. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info); + +//------------------------------------------------------------ +// Section: DLPack support APIs +//------------------------------------------------------------ +/*! + * \brief Produce a managed NDArray from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output NDArray handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out); + +/*! + * \brief Produce a DLMangedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); + +/*! + * \brief Produce a managed NDArray from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output NDArray handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, + int32_t require_alignment, + int32_t require_contiguous, + TVMFFIObjectHandle* out); + +/*! + * \brief Produce a DLMangedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, + DLManagedTensorVersioned** out); + +//--------------------------------------------------------------- +// Section: dtype string support APIs. +// These APIs are used to simplify the dtype printings during FFI +//--------------------------------------------------------------- + +/*! + * \brief Convert a string to a DLDataType. + * \param str The string to convert. + * \param out The output DLDataType. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); + +/*! + * \brief Convert a DLDataType to a string. + * \param dtype The DLDataType to convert. + * \param out The output string. + * \return 0 when success, nonzero when failure happens + * \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. + The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. + */ +TVM_FFI_DLL int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out); + +//------------------------------------------------------------ +// Section: Backend noexcept functions for internal use +// +// These functions are used internally and do not throw error +// instead the error will be logged and abort the process +// These are function are being called in startup or exit time +// so exception handling do not apply +//------------------------------------------------------------ +/*! + * \brief Get stack traceback in a string. + * \param filename The current file name. + * \param lineno The current line number + * \param func The current function + * \return The traceback string + * + * \note filename func and lino are only used as a backup info, most cases they are not needed. + * The return value is set to const char* to be more compatible across dll boundaries. + */ +TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, + const char* func); + +/*! + * \brief Initialize the type info during runtime. + * + * When the function is first time called for a type, + * it will register the type to the type table in the runtime. + * + * If the static_tindex is non-negative, the function will + * allocate a runtime type index. + * Otherwise, we will populate the type table and return the static index. + * + * \param type_key The type key. + * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index + * \param num_child_slots Number of slots reserved for its children. + * \param child_slots_can_overflow Whether to allow child to overflow the slots. + * \param parent_type_index Parent type index, pass in -1 if it is root. + * \param result The output type index + * + * \return 0 if success, -1 if error occured + */ +TVM_FFI_DLL int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, + int32_t static_type_index, int32_t type_depth, + int32_t num_child_slots, + int32_t child_slots_can_overflow, + int32_t parent_type_index); + +/*! + * \brief Get dynamic type info by type index. + * + * \param type_index The type index + * \param result The output type information + * \return The type info + */ +TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); + +#ifdef __cplusplus +} // TVM_FFI_EXTERN_C +#endif + +//--------------------------------------------------------------- +// The following API defines static object field accessors +// for language bindings. +// +// They are defined in C++ inline functions for cleaner code. +// Note that they only have to do with address offset computation. +// So they can always be reimplemented in bindings when c++ is +// not available or when binding only wants to refer to the dll. +//---------------------------------------------------------------- +#ifdef __cplusplus +/*! + * \brief Get the type index of an object. + * \param obj The object handle. + * \return The type index. + */ +inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { + return static_cast(obj)->type_index; +} + +/*! + * \brief Get the data pointer of a bytearray from a string or bytes object. + * \param obj The object handle. + * \return The data pointer. + */ +inline TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a ErrorInfo from an Error object. + * \param obj The object handle. + * \return The data pointer. + */ +inline TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a function cell from a function object. + * \param obj The object handle. + * \return The data pointer. + */ +inline TVMFFIFunctionCell* TVMFFIFunctionGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a shape array from a shape object. + * \param obj The object handle. + * \return The data pointer. + */ +inline TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the DLTensor pointer from an NDArray object. + * \param obj The object handle. + * \return The DLTensor pointer. + */ +inline DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Create a DLDevice from a device type and device id. + * \param device_type The device type. + * \param device_id The device id. + * \return The DLDevice. + */ +inline DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) { + return DLDevice{static_cast(device_type), device_id}; +} +#endif // __cplusplus +#endif // TVM_FFI_C_API_H_ diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h new file mode 100644 index 000000000000..99069fb13b3c --- /dev/null +++ b/ffi/include/tvm/ffi/cast.h @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/cast.h + * \brief Value casting support + */ +#ifndef TVM_FFI_CAST_H_ +#define TVM_FFI_CAST_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the object alive beyond the scope of the function. + * + * \param ptr The object pointer + * \tparam RefType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +TVM_FFI_INLINE RefType GetRef(const ObjectType* ptr) { + static_assert(std::is_base_of_v, + "Can only cast to the ref of same container type"); + + if constexpr (is_optional_type_v || RefType::_type_is_nullable) { + if (ptr == nullptr) { + return RefType(ObjectPtr(nullptr)); + } + } else { + TVM_FFI_ICHECK_NOTNULL(ptr); + } + return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned( + const_cast(static_cast(ptr)))); +} + +/*! + * \brief Get an object ptr type from a raw object ptr. + * + * \param ptr The object pointer + * \tparam BaseType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +inline ObjectPtr GetObjectPtr(ObjectType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); +} + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template >> +inline SubRef Downcast(BaseRef ref) { + if (ref.defined()) { + if (!ref->template IsInstance()) { + TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } + return SubRef(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); + } else { + if constexpr (is_optional_type_v || SubRef::_type_is_nullable) { + return SubRef(ObjectPtr(nullptr)); + } + TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" + << SubRef::ContainerType::_type_key + << "` is not allowed. Use Downcast> instead."; + TVM_FFI_UNREACHABLE(); + } +} + +/*! + * \brief Downcast any to a specific type + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam T The target specific reference type. + */ +template +inline T Downcast(const Any& ref) { + if constexpr (std::is_same_v) { + return ref; + } else { + return ref.cast(); + } +} + +/*! + * \brief Downcast any to a specific type + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam T The target specific reference type. + */ +template +inline T Downcast(Any&& ref) { + if constexpr (std::is_same_v) { + return std::move(ref); + } else { + return std::move(ref).cast(); + } +} + +/*! + * \brief Downcast std::optional to std::optional + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam OptionalType The target optional type + */ +template >> +inline OptionalType Downcast(const std::optional& ref) { + if (ref.has_value()) { + if constexpr (std::is_same_v) { + return ref.value(); + } else { + return ref.value().cast(); + } + } else { + return OptionalType(std::nullopt); + } +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CAST_H_ diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h new file mode 100644 index 000000000000..5922eacb10f5 --- /dev/null +++ b/ffi/include/tvm/ffi/container/array.h @@ -0,0 +1,1035 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/array.h + * \brief Array type. + * + * tvm::ffi::Array is an erased type that contains list of content + */ +#ifndef TVM_FFI_CONTAINER_ARRAY_H_ +#define TVM_FFI_CONTAINER_ARRAY_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! \brief array node content in array */ +class ArrayObj : public Object, public details::InplaceArrayBase { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const Any at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const Any* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const Any* end() const { return begin() + size_; } + + /*! \brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, Any item) { this->operator[](i) = std::move(item); } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr CopyFrom(int64_t cap, ArrayObj* from) { + int64_t size = from->size_; + if (size > cap) { + TVM_FFI_THROW(ValueError) << "not enough capacity"; + } + ObjectPtr p = ArrayObj::Empty(cap); + Any* write = p->MutableBegin(); + Any* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) Any(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr MoveFrom(int64_t cap, ArrayObj* from) { + int64_t size = from->size_; + if (size > cap) { + TVM_FFI_THROW(RuntimeError) << "not enough capacity"; + } + ObjectPtr p = ArrayObj::Empty(cap); + Any* write = p->MutableBegin(); + Any* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) Any(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr CreateRepeated(int64_t n, const Any& val) { + ObjectPtr p = ArrayObj::Empty(n); + Any* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) Any(val); + } + return p; + } + + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray; + static constexpr const char* _type_key = "object.Array"; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ArrayObj, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + Any* MutableBegin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } + + /*! \return end mutable iterator */ + Any* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayObj with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr Empty(int64_t n = kInitSize) { + TVM_FFI_ICHECK_GE(n, 0); + ObjectPtr p = make_inplace_array_object(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template + ArrayObj* InitRange(int64_t idx, IterType first, IterType last) { + Any* itr = MutableBegin() + idx; + for (; first != last; ++first) { + Any ref = *first; + new (itr++) Any(std::move(ref)); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayObj* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* from = MutableBegin() + src_begin; + Any* to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayObj* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* from = MutableBegin() + src_end; + Any* to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayObj* EnlargeBy(int64_t delta, const Any& val = Any()) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) Any(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayObj* ShrinkBy(int64_t delta) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->Any::~Any(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayObj */ + static constexpr int64_t kInitSize = 4; + + /*! \brief Expansion factor of the Array */ + static constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase; + + // Reference class + template + friend class Array; + + template + friend class Tuple; + + template + friend struct TypeTraits; + + // To specialize make_object + friend ObjectPtr make_object<>(); +}; + +/*! \brief Helper struct for type-checking + * + * is_valid_iterator::value will be true if IterType can + * be dereferenced into a type that can be stored in an Array, and + * false otherwise. + */ +template +struct is_valid_iterator + : std::bool_constant< + std::is_same_v< + T, std::remove_cv_t())>>> || + std::is_base_of_v< + T, std::remove_cv_t())>>>> { +}; + +template +struct is_valid_iterator, IterType> : is_valid_iterator {}; + +template +struct is_valid_iterator : std::true_type {}; + +template +inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; + +/*! + * \brief Array, container representing a contiguous sequence of ObjectRefs. + * + * Array implements in-place copy-on-write semantics. + * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. + * \tparam T The content Value type, must be compatible with tvm::ffi::Any + */ +template >> +class Array : public ObjectRef { + public: + using value_type = T; + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayObj::Empty(); } + Array(Array&& other) : ObjectRef(std::move(other.data_)) {} + Array(const Array& other) : ObjectRef(other.data_) {} + template >> + Array(Array&& other) : ObjectRef(std::move(other.data_)) {} + template >> + Array(const Array& other) : ObjectRef(other.data_) {} + + TVM_FFI_INLINE Array& operator=(Array&& other) { + data_ = std::move(other.data_); + return *this; + } + TVM_FFI_INLINE Array& operator=(const Array& other) { + data_ = other.data_; + return *this; + } + template >> + TVM_FFI_INLINE Array& operator=(Array&& other) { + data_ = std::move(other.data_); + return *this; + } + template >> + TVM_FFI_INLINE Array& operator=(const Array& other) { + data_ = other.data_; + return *this; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType first, IterType last) { + static_assert(is_valid_iterator_v, + "IterType cannot be inserted into a tvm::Array"); + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayObj::CreateRepeated(n, val); } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const Any& n) { + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(n); + } + }; + + using iterator = details::IterAdapter; + using reverse_iterator = details::ReverseIterAdapter; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayObj()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayObj()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayObj::end() is never nullptr + return reverse_iterator(GetArrayObj()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayObj::begin() is never nullptr + return reverse_iterator(GetArrayObj()->begin() - 1); + } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayObj* p = GetArrayObj(); + if (p == nullptr) { + TVM_FFI_THROW(IndexError) << "cannot index a null array"; + } + if (i < 0 || i >= p->size_) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; + } + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayObj* p = GetArrayObj(); + return p == nullptr ? 0 : GetArrayObj()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayObj* p = GetArrayObj(); + return p == nullptr ? 0 : GetArrayObj()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayObj* p = GetArrayObj(); + if (p == nullptr || p->size_ == 0) { + TVM_FFI_THROW(IndexError) << "cannot index a empty array"; + } + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayObj* p = GetArrayObj(); + if (p == nullptr || p->size_ == 0) { + TVM_FFI_THROW(IndexError) << "cannot index a empty array"; + } + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*(p->end() - 1)); + } + + public: + // mutation in std::vector, implements copy-on-write + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayObj* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayObj()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) Any(val); + } + + /*! + * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template + void insert(iterator position, IterType first, IterType last) { + static_assert(is_valid_iterator_v, + "IterType cannot be inserted into a tvm::Array"); + + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayObj()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array"; + } + int64_t size = GetArrayObj()->size_; + if (size == 0) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array"; + } + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; + } + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayObj()->size_; + if (st < 0 || st >= size) { + TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because Array size is " + << size; + } + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! + * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; + } + int64_t size = GetArrayObj()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + if (st >= ed) { + TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")"; + } + if (st < 0 || st > size || ed < 0 || ed > size) { + TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + } + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + if (n < 0) { + TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size"; + } + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayObj()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayObj()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayObj* p = CopyOnWrite(); + p->clear(); + } + } + + template + static size_t CalcCapacityImpl() { + return 0; + } + + template + static size_t CalcCapacityImpl(Array value, Args... args) { + return value.size() + CalcCapacityImpl(args...); + } + + template + static size_t CalcCapacityImpl(T value, Args... args) { + return 1 + CalcCapacityImpl(args...); + } + + template + static void AgregateImpl(Array& dest) {} // NOLINT(*) + + template + static void AgregateImpl(Array& dest, Array value, Args... args) { // NOLINT(*) + dest.insert(dest.end(), value.begin(), value.end()); + AgregateImpl(dest, args...); + } + + template + static void AgregateImpl(Array& dest, T value, Args... args) { // NOLINT(*) + dest.push_back(value); + AgregateImpl(dest, args...); + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayObj* p = this->CopyOnWrite(); + if (i < 0 || i >= p->size_) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; + } + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayObj */ + ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } + + /*! + * \brief Helper function to apply a map function onto the array. + * + * \param fmap The transformation function T -> U. + * + * \tparam F The type of the mutation function. + * + * \tparam U The type of the returned array, inferred from the + * return type of F. If overridden by the user, must be something + * that is convertible from the return type of F. + * + * \note This function performs copy on write optimization. If + * `fmap` returns an object of type `T`, and all elements of the + * array are mapped to themselves, then the returned array will be + * the same as the original, and reference counts of the elements in + * the array will not be incremented. + * + * \return The transformed array. + */ + template > + Array Map(F fmap) const { + return Array(MapHelper(data_, fmap)); + } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template >>> + void MutateByApply(F fmutate) { + data_ = MapHelper(std::move(data_), fmutate); + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + void Assign(IterType first, IterType last) { + int64_t cap = std::distance(first, last); + if (cap < 0) { + TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size"; + } + ArrayObj* p = GetArrayObj(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayObj::Empty(cap); + p = GetArrayObj(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + Any* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) Any(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayObj* CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayObj::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayObj; + + /*! + * \brief Agregate arguments into a single Array + * \param args sequence of T or Array elements + * \return Agregated Array + */ + template + static Array Agregate(Args... args) { + Array result; + result.reserve(CalcCapacityImpl(args...)); + AgregateImpl(result, args...); + return result; + } + + private: + /*! + * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayObj pointer to the unique copy + */ + ArrayObj* CopyOnWrite(int64_t reserve_extra) { + ArrayObj* p = GetArrayObj(); + if (p == nullptr) { + // necessary to get around the constexpr address issue before c++17 + const int64_t kInitSize = ArrayObj::kInitSize; + return SwitchContainer(std::max(kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayObj::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayObj to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayObj* SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayObj::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayObj::MoveFrom(capacity, GetArrayObj()); + } else { + data_ = ArrayObj::CopyFrom(capacity, GetArrayObj()); + } + return static_cast(data_.get()); + } + + /*! \brief Helper method for mutate/map + * + * A helper function used internally by both `Array::Map` and + * `Array::MutateInPlace`. Given an array of data, apply the + * mapping function to each element, returning the collected array. + * Applies both mutate-in-place and copy-on-write optimizations, if + * possible. + * + * \param data A pointer to the ArrayObj containing input data. + * Passed by value to allow for mutate-in-place optimizations. + * + * \param fmap The mapping function + * + * \tparam F The type of the mutation function. + * + * \tparam U The output type of the mutation function. Inferred + * from the callable type given. Must inherit from ObjectRef. + * + * \return The mapped array. Depending on whether mutate-in-place + * or copy-on-write optimizations were applicable, may be the same + * underlying array as the `data` parameter. + */ + template > + static ObjectPtr MapHelper(ObjectPtr data, F fmap) { + if (data == nullptr) { + return nullptr; + } + + TVM_FFI_ICHECK(data->IsInstance()); + + constexpr bool is_same_output_type = std::is_same_v; + + if constexpr (is_same_output_type) { + if (data.unique()) { + // Mutate-in-place path. Only allowed if the output type U is + // the same as type T, we have a mutable this*, and there are + // no other shared copies of the array. + auto arr = static_cast(data.get()); + for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { + T value = details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*it); + // reset the original value to nullptr, to ensure unique ownership + it->reset(); + T mapped = fmap(std::move(value)); + *it = std::move(mapped); + } + return data; + } + } + + constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; + + ObjectPtr output = nullptr; + auto arr = static_cast(data.get()); + + auto it = arr->begin(); + if constexpr (compatible_types) { + // Copy-on-write path, if the output Array might be + // represented by the same underlying array as the existing + // Array. Typically, this is for functions that map `T` to + // `T`, but can also apply to functions that map `T` to + // `Optional`, or that map `T` to a subclass or superclass of + // `T`. + bool all_identical = true; + for (; it != arr->end(); it++) { + U mapped = fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*it)); + if (!(*it).same_as(mapped)) { + // At least one mapped element is different than the + // original. Therefore, prepare the output array, + // consisting of any previous elements that had mapped to + // themselves (if any), and the element that didn't map to + // itself. + // + // We cannot use `U()` as the default object, as `U` may be + // a non-nullable type. Since the default `Any()` + // will be overwritten before returning, all objects will be + // of type `U` for the calling scope. + all_identical = false; + output = ArrayObj::CreateRepeated(arr->size(), Any()); + output->InitRange(0, arr->begin(), it); + output->SetItem(it - arr->begin(), std::move(mapped)); + it++; + break; + } + } + if (all_identical) { + return data; + } + } else { + // Path for incompatible types. The constexpr check for + // compatible types isn't strictly necessary, as the first + // (*it).same_as(mapped) would return false, but we might as well + // avoid it altogether. + // + // We cannot use `U()` as the default object, as `U` may be a + // non-nullable type. Since the default `Any()` will be + // overwritten before returning, all objects will be of type `U` + // for the calling scope. + output = ArrayObj::CreateRepeated(arr->size(), Any()); + } + + // Normal path for incompatible types, or post-copy path for + // copy-on-write instances. + // + // If the types are incompatible, then at this point `output` is + // empty, and `it` points to the first element of the input. + // + // If the types were compatible, then at this point `output` + // contains zero or more elements that mapped to themselves + // followed by the first element that does not map to itself, and + // `it` points to the element just after the first element that + // does not map to itself. Because at least one element has been + // changed, we no longer have the opportunity to avoid a copy, so + // we don't need to check the result. + // + // In both cases, `it` points to the next element to be processed, + // so we can either start or resume the iteration from that point, + // with no further checks on the result. + for (; it != arr->end(); it++) { + U mapped = fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*it)); + output->SetItem(it - arr->begin(), std::move(mapped)); + } + + return output; + } + template + friend class Array; +}; + +/*! + * \brief Concat two Arrays. + * \param lhs first Array to be concatenated. + * \param rhs second Array to be concatenated. + * \return The concatenated Array. Original Arrays are kept unchanged. + */ +template || + TypeTraits::convert_enabled>> +inline Array Concat(Array lhs, const Array& rhs) { + for (const auto& x : rhs) { + lhs.push_back(x); + } + return std::move(lhs); +} + +// Specialize make_object to make sure it is correct. +template <> +inline ObjectPtr make_object() { + return ArrayObj::Empty(); +} + +// Traits for Array +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public ObjectRefTypeTraitsBase> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray; + using ObjectRefTypeTraitsBase>::CopyFromAnyStorageAfterCheck; + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v) { + const ArrayObj* n = reinterpret_cast(src->v_obj); + for (size_t i = 0; i < n->size(); i++) { + const Any& any_v = (*n)[i]; + // CheckAnyStorage is cheaper than as + if (details::AnyUnsafe::CheckAnyStorage(any_v)) continue; + // try see if p is convertible to T + if (any_v.as()) continue; + // now report the accurate mismatch information + return "Array[index " + std::to_string(i) + ": " + + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; + } + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) return false; + if constexpr (std::is_same_v) { + return true; + } else { + const ArrayObj* n = reinterpret_cast(src->v_obj); + for (size_t i = 0; i < n->size(); i++) { + const Any& any_v = (*n)[i]; + if (!details::AnyUnsafe::CheckAnyStorage(any_v)) return false; + } + return true; + } + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + // try to run conversion. + if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; + if constexpr (!std::is_same_v) { + const ArrayObj* n = reinterpret_cast(src->v_obj); + bool storage_check = [&]() { + for (size_t i = 0; i < n->size(); i++) { + const Any& any_v = (*n)[i]; + if (!details::AnyUnsafe::CheckAnyStorage(any_v)) return false; + } + return true; + }(); + // fast path, if storage check passes, we can return the array directly. + if (storage_check) { + return CopyFromAnyStorageAfterCheck(src); + } + // slow path, try to run a conversion to Array + Array result; + result.reserve(n->size()); + for (size_t i = 0; i < n->size(); i++) { + const Any& any_v = (*n)[i]; + if (auto opt_v = any_v.as()) { + result.push_back(*std::move(opt_v)); + } else { + return std::nullopt; + } + } + return result; + } else { + return CopyFromAnyStorageAfterCheck(src); + } + } + + static TVM_FFI_INLINE std::string TypeStr() { return "Array<" + details::Type2Str::v() + ">"; } +}; + +namespace details { +template +inline constexpr bool type_contains_v, Array> = type_contains_v; +} // namespace details + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_ARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h new file mode 100644 index 000000000000..51e130f37385 --- /dev/null +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/container_details.h + * \brief Common utilities for typed container types. + */ +#ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ +#define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace details { +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType. + * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayObj : public tvm::ffi::details::InplaceArrayBase { + * public: + * // Wrap EmplaceInit to initialize the elements + * template + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * } + * + * void test_function() { + * vector fields; + * auto ptr = make_inplace_array_object(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template +class InplaceArrayBase { + public: + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType& operator[](size_t idx) const { + size_t size = Self()->GetSize(); + if (idx > size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; + } + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) { + size_t size = Self()->GetSize(); + if (idx > size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; + } + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Destroy the Inplace Array Base object + */ + ~InplaceArrayBase() { + if constexpr (!(std::is_standard_layout::value && std::is_trivial::value)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + protected: + /*! + * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template + void EmplaceInit(size_t idx, Args&&... args) { + void* field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward(args)...); + } + + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast(const_cast(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { + ++iter_; + return *this; + } + ReverseIterAdapter operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { + return ReverseIterAdapter(iter_ - offset); + } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief Check if T is compatible with Any. + * + * \tparam T The type to check. + * \return True if T is compatible with Any, false otherwise. + */ +template +inline constexpr bool storage_enabled_v = std::is_same_v || TypeTraits::storage_enabled; + +/*! + * \brief Check if all T are compatible with Any. + * + * \tparam T The type to check. + * \return True if T is compatible with Any, false otherwise. + */ +template +inline constexpr bool all_storage_enabled_v = (storage_enabled_v && ...); + +/** + * \brief Check if Any storage of Derived can always be directly used as Base. + * + * \tparam Base The base type. + * \tparam Derived The derived type. + * \return True if Derived's storage can be used as Base's storage, false otherwise. + */ +template +inline constexpr bool type_contains_v = + std::is_base_of_v || std::is_same_v; +// special case for Any +template +inline constexpr bool type_contains_v = true; + +/*! + * \brief Create a string of the container type. + * \tparam V The types of the elements in the container. + * \param name The name of the container type. + * \return A string of the container type. + */ +template +std::string ContainerTypeStr(const char* name) { + std::stringstream ss; + // helper to construct concated string of TypeStr + class TypeStrHelper { + public: + TypeStrHelper(std::stringstream& stream) : stream_(stream) {} // NOLINT(*) + + TypeStrHelper& operator<<(const std::string& str) { + if (counter_ > 0) { + stream_ << ", "; + } + stream_ << str; + counter_++; + return *this; + } + + private: + std::stringstream& stream_; // NOLINT(*) + int counter_ = 0; + }; + TypeStrHelper helper(ss); + ss << name << '<'; + (helper << ... << Type2Str::v()); + ss << '>'; + return ss.str(); +} + +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h new file mode 100644 index 000000000000..7d805d61e9ee --- /dev/null +++ b/ffi/include/tvm/ffi/container/map.h @@ -0,0 +1,1606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/map.h + * \brief Runtime Map container types. + */ +#ifndef TVM_FFI_CONTAINER_MAP_H_ +#define TVM_FFI_CONTAINER_MAP_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE +#define TVM_FFI_MAP_FAIL_IF_CHANGED() \ + TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; +#else +#define TVM_FFI_MAP_FAIL_IF_CHANGED() +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + +/*! \brief Shared content of all specializations of hash map */ +class MapObj : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = Any; + /*! \brief Type of the values in the hash map */ + using mapped_type = Any; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /*! \brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); + + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; + static constexpr const char* _type_key = "object.Map"; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(MapObj, Object); + + /*! + * \brief Number of elements in the SmallMapObj + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position); + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { erase(find(key)); } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType*; + using reference = KVType&; +/*! \brief Default constructor */ +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + iterator() : state_marker(0), index(0), self(nullptr) {} +#else + iterator() : index(0), self(nullptr) {} +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + return index == other.index && self == other.self; + } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + return *((*this).operator->()); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + TVM_FFI_MAP_FAIL_IF_CHANGED() + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + TVM_FFI_MAP_FAIL_IF_CHANGED() + iterator copy = *this; + --(*this); + return copy; + } + + protected: +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + uint64_t state_marker; + /*! \brief Construct by value */ + iterator(uint64_t index, const MapObj* self) + : state_marker(self->state_marker), index(index), self(self) {} + +#else + iterator(uint64_t index, const MapObj* self) : index(index), self(self) {} +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! \brief The position on the array */ + uint64_t index; + /*! \brief The container it points to */ + const MapObj* self; + + friend class DenseMapObj; + friend class SmallMapObj; + }; + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + + protected: +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + uint64_t state_marker; +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(KVType&& kv, ObjectPtr* map); + /*! + * \brief Create an empty container with elements copying from another SmallMapObj + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapObj* from); + /*! \brief number of slots minus 1 */ + uint64_t slots_; + /*! \brief number of entries in the container */ + uint64_t size_; + // Reference class + template + friend class Map; +}; + +/*! \brief A specialization of small-sized hash map */ +class SmallMapObj : public MapObj, public details::InplaceArrayBase { + private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + + public: + using MapObj::iterator; + using MapObj::KVType; + + /*! \brief Defaults to the destructor of InplaceArrayBase */ + ~SmallMapObj() = default; + /*! + * \brief Count the number of times a key exists in the SmallMapObj + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { + iterator itr = find(key); + if (itr.index >= size_) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return itr->second; + } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { + iterator itr = find(key); + if (itr.index >= size_) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return itr->second; + } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + KVType* ptr = static_cast(AddressOf(0)); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (AnyEqual()(ptr->first, key)) { + return iterator(i, this); + } + } + return iterator(size_, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { Erase(position.index); } + + private: + /*! + * \brief Remove a position in SmallMapObj + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType* begin = static_cast(AddressOf(0)); + // call destructor to destroy the item in `begin + index` + // Explicit call Any::~Any() to destroy the Any object + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (begin + index)->first.Any::~Any(); + (begin + index)->second.Any::~Any(); + // IMPORTANT: We do direct raw memmove to bring later items to the current position + // to preserve the order of insertion. + // This works because direct memory copy preserves the Any's move semantics. + if (index + 1 < size_) { + std::memmove(reinterpret_cast(begin + index), + reinterpret_cast(begin + index + 1), + (size_ - index - 1) * sizeof(KVType)); + } + size_ -= 1; + } + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::ffi::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->size_ = 0; + p->slots_ = n; + return p; + } + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType* ptr = static_cast(p->AddressOf(0)); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another SmallMapObj + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapObj* from) { + KVType* first = static_cast(from->AddressOf(0)); + KVType* last = first + from->size_; + return CreateFromRange(from->size_, first, last); + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { + SmallMapObj* map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->slots_) { + KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); + new (ptr) KVType(std::move(kv)); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); + next_size = std::min(next_size, uint64_t(kMaxSize)); + TVM_FFI_ICHECK_GT(next_size, map_node->slots_); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + + protected: + friend class MapObj; + friend class DenseMapObj; + friend class details::InplaceArrayBase; +}; + +/*! \brief A specialization of hash map that implements the idea of array-based hash map. + * Another reference implementation can be found [1]. + * + * A. Overview + * + * DenseMapObj did several improvements over traditional separate chaining hash, + * in terms of cache locality, memory footprints and data organization. + * + * A1. Implicit linked list. For better cache locality, instead of using linked list + * explicitly for each bucket, we store list data into a single array that spans contiguously + * in memory, and then carefully design access patterns to make sure most of them fall into + * a single cache line. + * + * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and + * traversal. This can be divided in 3 parts. + * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, + * which means the slot is empty but not allowed to be written. + * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is + * head of a linked list. + * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit + * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when + * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are + * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to + * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, + * then x must be one of the 126 pre-defined values. + * + * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. + * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. + * 16 key-value pairs. + * + * B. Implementation details + * + * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid + * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, + * we use the Fibonacci Hashing [2] trick. + * + * B2. Traverse a linked list in the array. + * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i + * indicates that it is list head, then we found the head; otherwise the list is empty. No probing + * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we + * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of + * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). + * + * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this + * element is in the linked list, and if not, we put it at the end by probing the next empty + * position in one of the 126 candidate positions. If the linked list does not even exist, but the + * slot for list head has been occupied by another linked list, we should find this intruder another + * place. + * + * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing + * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the + * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list + * head. + * + * [1] https://github.com/skarupke/flat_hash_map + * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ + * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + */ +class DenseMapObj : public MapObj { + private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! \brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); + /*! \brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Index indicator to indicate an invalid index */ + static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief item type of the dense map, including a kv data and prev/next pointer */ + struct ItemType { + KVType data; + uint64_t prev = kInvalidIndex; + uint64_t next = kInvalidIndex; + + explicit ItemType(KVType&& data) : data(std::move(data)) {} + explicit ItemType(key_type key, mapped_type value) : data(key, value) {} + }; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout::value, "Block is not standard layout"); + + public: + using MapObj::iterator; + + /*! + * \brief Destroy the DenseMapObj + */ + ~DenseMapObj() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type& key) const { return !Search(key).IsNone(); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return At(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { + uint64_t index = position.index; + if (position.self != nullptr && index <= this->slots_) { + Erase(ListNode(index, this)); + } + } + /*! \return begin iterator */ + iterator begin() const { return iterator(iter_list_head_, this); } + /*! \return end iterator */ + iterator end() const { return iterator(kInvalidIndex, this); } + + private: + /*! + * \brief Unlink the entry from iterator list + * \param node The node to be unlinked + * \note This function is usually used before deletion, + * and it does not change data content of the node. + */ + void IterListUnlink(ListNode node) { + // update head and tail of iterator list if needed + if (node.Item().prev == kInvalidIndex) { + iter_list_head_ = node.Item().next; + } else { + ListNode prev_node(node.Item().prev, this); + prev_node.Item().next = node.Item().next; + } + if (node.Item().next == kInvalidIndex) { + iter_list_tail_ = node.Item().prev; + } else { + ListNode next_node(node.Item().next, this); + next_node.Item().prev = node.Item().prev; + } + } + /*! + * \brief Insert the entry into tail of iterator list + * \param node The node to be inserted + * \note this function does not change data content of the node. + */ + void IterListPushBack(ListNode node) { + node.Item().prev = iter_list_tail_; + node.Item().next = kInvalidIndex; + if (iter_list_tail_ != kInvalidIndex) { + ListNode prev_node(iter_list_tail_, this); + prev_node.Item().next = node.index; + } + if (iter_list_head_ == kInvalidIndex) { + iter_list_head_ = node.index; + } + iter_list_tail_ = node.index; + } + /*! + * \brief Replace node src by dst in the iter list + * \param src The source node + * \param dst The destination node, must be empty + * \note This function does not change data content of the nodes, + * which needs to be updated by the caller. + */ + void IterListReplaceNodeBy(ListNode src, ListNode dst) { + // set link correctly on the dst + dst.Item().prev = src.Item().prev; + dst.Item().next = src.Item().next; + // update prev and next of dst + if (dst.Item().prev == kInvalidIndex) { + iter_list_head_ = dst.index; + } else { + ListNode prev_node(dst.Item().prev, this); + prev_node.Item().next = dst.index; + } + if (dst.Item().next == kInvalidIndex) { + iter_list_tail_ = dst.index; + } else { + ListNode next_node(dst.Item().next, this); + next_node.Item().prev = dst.index; + } + } + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type& key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (AnyEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); + } + /*! + * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type& At(const key_type& key) const { + ListNode iter = Search(key); + if (iter.IsNone()) { + TVM_FFI_THROW(IndexError) << "key is not in Map"; + } + return iter.Val(); + } + /*! + * \brief Try to insert a key, or do nothing if already exists + * \param key The indexing key + * \param result The linked-list entry found or just constructed + * \return A boolean, indicating if actual insertion happens + */ + bool TryInsert(const key_type& key, ListNode* result) { + if (slots_ == 0) { + return false; + } + // required that `iter` to be the head of a linked list through which we can iterator + ListNode iter = IndexFromHash(AnyHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(ItemType(key, Any(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // find equal item, do not insert + if (AnyEqual()(key, next.Key())) { + // we plan to take next, so we need to unlink it from iterator list + IterListUnlink(next); + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(ItemType(key, Any(nullptr))); + // link `iter` to `empty`, and move forward + iter.SetJump(jump); + this->size_ += 1; + return true; + } + /*! + * \brief Spare an entry to be the head of a linked list. + * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + // first move the data over + empty.NewTail(ItemType(std::move(r.Data()))); + // then move link list chain of r to empty + // this needs to happen after NewTail so empty's prev/next get updated + IterListReplaceNodeBy(r, empty); + // explicit call destructor to destroy the item in `r` + r.DestructData(); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done moving the linked list + // fill data_ into `target` + target.NewHead(ItemType(key, Any(nullptr))); + this->size_ += 1; + *result = target; + return true; + } + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode& iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + // unlink the node from iterator list + IterListUnlink(iter); + // IMPORTANT: must explicit call destructor `iter` to avoid memory leak + // This is because we need to recycle iter's data + iter.DestructData(); + // set the meta data to be empty + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + // needs to first unlink iter from the list + IterListUnlink(iter); + // move data from last to iter + iter.Data() = std::move(last.Data()); + // Move link chain of iter to last as we stores last node to the new iter loc. + IterListReplaceNodeBy(last, iter); + // IMPORTANT: must explicit call destructor `last` to avoid memory leak + // likely we don't need this in this particular case because Any move behavior + // keep it here to be safe so code do not depend on specific move behavior of KVType + last.DestructData(); + // set the meta data to be empty + last.SetEmpty(); + prev.SetJump(0); + } + } + /*! \brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->slots_); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = data_[bi].bytes; + ItemType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + data_ptr->ItemType::~ItemType(); + } + } + } + ReleaseMemory(); + } + /*! \brief Release the memory acquired by the container without deleting its entries stored inside + */ + void ReleaseMemory() { + delete[] data_; + data_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; + } + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots - 1); + Block* block = p->data_ = new Block[n_blocks]; + p->slots_ = n_slots - 1; + p->size_ = 0; + p->fib_shift_ = fib_shift; + p->iter_list_head_ = kInvalidIndex; + p->iter_list_tail_ = kInvalidIndex; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another DenseMapObj + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapObj* from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->slots_); + p->data_ = new Block[n_blocks]; + p->slots_ = from->slots_; + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + p->iter_list_head_ = from->iter_list_head_; + p->iter_list_tail_ = from->iter_list_tail_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr_from = from->data_[bi].bytes; + ItemType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); + uint8_t* meta_ptr_to = p->data_[bi].bytes; + ItemType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t& meta = *meta_ptr_to = *meta_ptr_from; + TVM_FFI_ICHECK(meta != kProtectedSlot); + if (meta != uint8_t(kEmptySlot)) { + new (data_ptr_to) ItemType(*data_ptr_from); + } + } + } + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { + DenseMapObj* map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. If succeed, we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = std::move(kv.second); + // update the iter list relation + map_node->IterListPushBack(iter); + return; + } + TVM_FFI_ICHECK_GT(map_node->slots_, uint64_t(SmallMapObj::kMaxSize)); + // Otherwise, start rehash + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); + + // need to insert in the same order as the original map + for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { + ListNode node(index, map_node); + // now try move src_data into the new map, note that src may still not + // be fully consumed into the call, but destructor will be called. + InsertMaybeReHash(std::move(node.Data()), &p); + // Important, needs to explicit call destructor in case move did remove + // node's internal item + index = node.Item().next; + // IMPORTANT: must explicit call destructor `node` to avoid memory leak + // We must call node.DestructData() here. + // This is because std::move() arguments in IterMaybeReHash may or may not + // explicitly move out the node.Data() + // Remove this call will cause memory leak very likely. + node.DestructData(); + } + InsertMaybeReHash(std::move(kv), &p); + map_node->ReleaseMemory(); + *map = p; + } + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether hash table is full + */ + bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + // keep at the end of iterator + if (index == kInvalidIndex) { + return index; + } + ListNode node(index, this); + return node.Item().next; + } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + // this is the end iterator, we need to return tail. + if (index == kInvalidIndex) { + return iter_list_tail_; + } + // circle around the iterator list, which is OK + ListNode node(index, this); + return node.Item().prev; + } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); + } + /*! \brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); + } + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { + uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; + return (n_slots + kBlockCap - 1) / kBlockCap; + } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + TVM_FFI_ICHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } + } + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. + * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; + } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapObj* self) + : index(index), block(self->data_ + (index / kBlockCap)) {} + /*! \brief Metadata on the entry */ + uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + ItemType& Item() const { + return *(reinterpret_cast(block->bytes + kBlockCap + + (index % kBlockCap) * sizeof(ItemType))); + } + /*! \brief Data on the entry */ + KVType& Data() const { return Item().data; } + /*! \brief Key on the entry */ + key_type& Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type& Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } + /*! \brief Destruct the item in the entry */ + void DestructData() const { + // explicit call destructor to destroy the item + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (&Data())->first.Any::~Any(); + (&Data())->second.Any::~Any(); + } + /*! \brief Set the entry to be protected */ + void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(ItemType v) const { + Meta() = 0b00000000; + new (&Item()) ItemType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(ItemType v) const { + Meta() = 0b10000000; + new (&Item()) ItemType(std::move(v)); + } + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapObj* self, uint8_t meta) { + uint64_t offset = NextProbeLocation(meta & 0b01111111); + if (offset == 0) { + index = 0; + + block = nullptr; + return false; + } + index = (index + offset) & (self->slots_); + block = self->data_ + (index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapObj* self) { return MoveToNext(self, Meta()); } + /*! \brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapObj* self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(AnyHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! \brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapObj* self, uint8_t* jump, ListNode* result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + ListNode candidate((index + NextProbeLocation(idx)) & (self->slots_), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block* block; + }; + + protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief array of data blocks */ + Block* data_; + /*! \brief the head of iterator list */ + uint64_t iter_list_head_ = kInvalidIndex; + /*! \brief the tail of iterator list */ + uint64_t iter_list_tail_ = kInvalidIndex; + + static uint64_t NextProbeLocation(size_t index) { + /* clang-format off */ + /*! \brief Candidates of probing distance */ + static const uint64_t kNextProbeLocation[kNumJumpDists] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + // Quadratic probing with triangle numbers. See also: + // 1) https://en.wikipedia.org/wiki/Quadratic_probing + // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + // 3) https://github.com/skarupke/flat_hash_map + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, + 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, + 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, + 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, + 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, + 2211, 2278, 2346, 2415, 2485, 2556, 2628, + // larger triangle numbers + 8515, 19110, 42778, 96141, 216153, + 486591, 1092981, 2458653, 5532801, 12442566, + 27993903, 62983476, 141717030, 318844378, 717352503, + 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, + 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, + 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, + 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, + 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, + 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, + }; + /* clang-format on */ + return kNextProbeLocation[index]; + } + friend class MapObj; +}; + +#define TVM_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapObj*; \ + using TDense = DenseMapObj*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapObj::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +#define TVM_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapObj*; \ + using TDense = const DenseMapObj*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapObj::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +inline MapObj::iterator::pointer MapObj::iterator::operator->() const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); +} + +inline MapObj::iterator& MapObj::iterator::operator++() { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline MapObj::iterator& MapObj::iterator::operator--() { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->DecItr(index); + return *this; + }); +} + +inline size_t MapObj::count(const key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +} + +inline const MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +} + +inline MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) { + TVM_DISPATCH_MAP(this, p, { return p->at(key); }); +} + +inline MapObj::iterator MapObj::begin() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); +} + +inline MapObj::iterator MapObj::end() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); +} + +inline MapObj::iterator MapObj::find(const MapObj::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +} + +inline void MapObj::erase(const MapObj::iterator& position) { + TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); +} + +#undef TVM_DISPATCH_MAP +#undef TVM_DISPATCH_MAP_CONST + +inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } + +inline ObjectPtr MapObj::CopyFrom(MapObj* from) { + if (from->slots_ <= SmallMapObj::kMaxSize) { + return SmallMapObj::CopyFrom(static_cast(from)); + } else { + return DenseMapObj::CopyFrom(static_cast(from)); + } +} + +template +inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) { + int64_t _cap = std::distance(first, last); + if (_cap < 0) { + return SmallMapObj::Empty(); + } + uint64_t cap = static_cast(_cap); + if (cap < SmallMapObj::kMaxSize) { + return SmallMapObj::CreateFromRange(cap, first, last); + } + uint32_t fib_shift; + uint64_t n_slots; + DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); + ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); + for (; first != last; ++first) { + KVType kv(*first); + DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); + } + return obj; +} + +inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { + constexpr uint64_t kSmallMapMaxSize = SmallMapObj::kMaxSize; + MapObj* base = static_cast(map->get()); +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + base->state_marker++; +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + if (base->slots_ < kSmallMapMaxSize) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else if (base->slots_ == kSmallMapMaxSize) { + if (base->size_ < base->slots_) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else { + ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); + DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); + } + } else { + DenseMapObj::InsertMaybeReHash(std::move(kv), map); + } +} + +template <> +inline ObjectPtr make_object<>() = delete; + +/*! + * \brief Map container of NodeRef->NodeRef in DSL graph. + * Map implements copy on write semantics, which means map is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam K The key NodeRef type. + * \tparam V The value NodeRef type. + */ +template && + details::storage_enabled_v>> +class Map : public ObjectRef { + public: + using key_type = K; + using mapped_type = V; + class iterator; + /*! + * \brief default constructor + */ + Map() { data_ = MapObj::Empty(); } + /*! + * \brief move constructor + * \param other source + */ + Map(Map&& other) : ObjectRef(std::move(other.data_)) {} + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map& other) : ObjectRef(other.data_) {} + + template && + details::type_contains_v>> + Map(Map&& other) : ObjectRef(std::move(other.data_)) {} + + template && + details::type_contains_v>> + Map(const Map& other) : ObjectRef(other.data_) {} + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; + } + + template && + details::type_contains_v>> + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + + template && + details::type_contains_v>> + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Map(IterType begin, IterType end) { + data_ = MapObj::CreateFromRange(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list> init) { + data_ = MapObj::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief constructor from unordered_map + * \param init The unordered_map + */ + template + Map(const std::unordered_map& init) { // NOLINT(*) + data_ = MapObj::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V at(const K& key) const { + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(GetMapObj()->at(key)); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V operator[](const K& key) const { return this->at(key); } + /*! \return The size of the array */ + size_t size() const { + MapObj* n = GetMapObj(); + return n == nullptr ? 0 : n->size(); + } + /*! \return The number of elements of the key */ + size_t count(const K& key) const { + MapObj* n = GetMapObj(); + return n == nullptr ? 0 : GetMapObj()->count(key); + } + /*! \return whether array is empty */ + bool empty() const { return size() == 0; } + /*! \brief Release reference to all the elements */ + void clear() { + MapObj* n = GetMapObj(); + if (n != nullptr) { + data_ = MapObj::Empty(); + } + } + /*! + * \brief set the Map. + * \param key The index key. + * \param value The value to be setted. + */ + void Set(const K& key, const V& value) { + CopyOnWrite(); + MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); + } + /*! \return begin iterator */ + iterator begin() const { return iterator(GetMapObj()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetMapObj()->end()); } + /*! \return find the key and returns the associated iterator */ + iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); } + /*! \return The value associated with the key, NullOpt if not found */ + std::optional Get(const K& key) const { + MapObj::iterator iter = GetMapObj()->find(key); + if (iter == GetMapObj()->end()) { + return std::nullopt; + } + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(iter->second); + } + void erase(const K& key) { CopyOnWrite()->erase(key); } + + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which guarantees to be unique) + */ + MapObj* CopyOnWrite() { + if (data_.get() == nullptr) { + data_ = MapObj::Empty(); + } else if (!data_.unique()) { + data_ = MapObj::CopyFrom(GetMapObj()); + } + return GetMapObj(); + } + /*! \brief specify container node */ + using ContainerType = MapObj; + + /*! \brief Iterator of the hash map */ + class iterator { + public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = int64_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + iterator() : itr() {} + + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { return itr == other.itr; } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return itr != other.itr; } + /*! \brief De-reference iterators is not allowed */ + pointer operator->() const = delete; + /*! \brief De-reference iterators */ + reference operator*() const { + auto& kv = *itr; + return std::make_pair(details::AnyUnsafe::CopyFromAnyStorageAfterCheck(kv.first), + details::AnyUnsafe::CopyFromAnyStorageAfterCheck(kv.second)); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++() { + ++itr; + return *this; + } + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--() { + --itr; + return *this; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + private: + iterator(const MapObj::iterator& itr) // NOLINT(*) + : itr(itr) {} + + template + friend class Map; + + MapObj::iterator itr; + }; + + private: + /*! \brief Return data_ as type of pointer of MapObj */ + MapObj* GetMapObj() const { return static_cast(data_.get()); } + + template + friend class Map; +}; + +/*! + * \brief Merge two Maps. + * \param lhs the first Map to merge. + * \param rhs the second Map to merge. + * @return The merged Array. Original Maps are kept unchanged. + */ +template && + details::storage_enabled_v>> +inline Map Merge(Map lhs, const Map& rhs) { + for (const auto& p : rhs) { + lhs.Set(p.first, p.second); + } + return std::move(lhs); +} + +// Traits for Map +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public ObjectRefTypeTraitsBase> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; + using ObjectRefTypeTraitsBase>::CopyFromAnyStorageAfterCheck; + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIMap) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v || !std::is_same_v) { + const MapObj* n = reinterpret_cast(src->v_obj); + for (const auto& kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStorage(kv.first) && !kv.first.as().has_value()) { + return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + + ", V]"; + } + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStorage(kv.second) && + !kv.second.as().has_value()) { + return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + + "]"; + } + } + } + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIMap) return false; + if constexpr (std::is_same_v && std::is_same_v) { + return true; + } else { + const MapObj* n = reinterpret_cast(src->v_obj); + for (const auto& kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStorage(kv.first)) return false; + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStorage(kv.second)) return false; + } + } + return true; + } + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt; + if constexpr (!std::is_same_v || !std::is_same_v) { + const MapObj* n = reinterpret_cast(src->v_obj); + bool storage_check = [&]() { + for (const auto& kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStorage(kv.first)) return false; + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStorage(kv.second)) return false; + } + } + return true; + }(); + // fast path, if storage check passes, we can return the array directly. + if (storage_check) return CopyFromAnyStorageAfterCheck(src); + // slow path, we need to create a new map and convert to the target type. + Map ret; + for (const auto& kv : *n) { + auto k = kv.first.as(); + auto v = kv.second.as(); + if (!k.has_value() || !v.has_value()) return std::nullopt; + ret.Set(*std::move(k), *std::move(v)); + } + return ret; + } else { + return CopyFromAnyStorageAfterCheck(src); + } + } + + static TVM_FFI_INLINE std::string TypeStr() { + return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; + } +}; + +namespace details { +template +inline constexpr bool type_contains_v, Map> = + type_contains_v && type_contains_v; +} // namespace details + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h new file mode 100644 index 000000000000..6acdbc3a2692 --- /dev/null +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -0,0 +1,337 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/ndarray.h + * \brief Container to store an NDArray. + */ +#ifndef TVM_FFI_CONTAINER_NDARRAY_H_ +#define TVM_FFI_CONTAINER_NDARRAY_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief check if a DLTensor is contiguous. + * \param arr The input DLTensor. + * \return The check result. + */ +inline bool IsContiguous(const DLTensor& arr) { + if (arr.strides == nullptr) return true; + int64_t expected_stride = 1; + for (int32_t i = arr.ndim; i != 0; --i) { + int32_t k = i - 1; + if (arr.shape[k] == 1) { + // Skip stride check if shape[k] is 1, where the dimension is contiguous + // regardless of the value of stride. + // + // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting + // to DLPack. + // More context: https://github.com/pytorch/pytorch/pull/83158 + continue; + } + if (arr.strides[k] != expected_stride) return false; + expected_stride *= arr.shape[k]; + } + return true; +} + +/** + * \brief Check if the data in the DLTensor is aligned to the given alignment. + * \param arr The input DLTensor. + * \param alignment The alignment to check. + * \return True if the data is aligned to the given alignment, false otherwise. + */ +inline bool IsAligned(const DLTensor& arr, size_t alignment) { + // whether the device uses direct address mapping instead of indirect buffer + bool direct_address = arr.device.device_type <= kDLCUDAHost || + arr.device.device_type == kDLCUDAManaged || + arr.device.device_type == kDLROCM || arr.device.device_type == kDLROCMHost; + if (direct_address) { + return (reinterpret_cast(static_cast(arr.data) + arr.byte_offset) % alignment == + 0); + } else { + return arr.byte_offset % alignment == 0; + } +} + +/*! + * \brief return the total number bytes needs to store packed data + * + * \param numel the number of elements in the array + * \param dtype the data type of the array + * \return the total number bytes needs to store packed data + */ +inline size_t GetDataSize(int64_t numel, DLDataType dtype) { + // compatible handling sub-byte uint1(bool), which usually stored as uint8_t + // TODO(tqchen): revisit and switch to kDLBool + if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { + return numel; + } + // for other sub-byte types, packing is preferred + return (numel * dtype.bits * dtype.lanes + 7) / 8; +} + +/*! + * \brief return the size of data the DLTensor hold, in term of number of bytes + * + * \param arr the input DLTensor + * \return number of bytes of data in the DLTensor. + */ +inline size_t GetDataSize(const DLTensor& arr) { + size_t size = 1; + for (int i = 0; i < arr.ndim; ++i) { + size *= static_cast(arr.shape[i]); + } + return GetDataSize(size, arr.dtype); +} + +/*! \brief An object representing an NDArray. */ +class NDArrayObj : public Object, public DLTensor { + public: + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFINDArray; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFINDArray; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(NDArrayObj, Object); + + /*! + * \brief Move NDArray to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensor* ToDLPack() const { + DLManagedTensor* ret = new DLManagedTensor(); + NDArrayObj* from = const_cast(this); + ret->dl_tensor = *static_cast(from); + ret->manager_ctx = from; + ret->deleter = DLManagedTensorDeleter; + details::ObjectUnsafe::IncRefObjectHandle(from); + return ret; + } + + /*! + * \brief Move NDArray to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensorVersioned* ToDLPackVersioned() const { + DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); + NDArrayObj* from = const_cast(this); + ret->version.major = DLPACK_MAJOR_VERSION; + ret->version.minor = DLPACK_MINOR_VERSION; + ret->dl_tensor = *static_cast(from); + ret->manager_ctx = from; + ret->deleter = DLManagedTensorVersionedDeleter; + ret->flags = 0; + details::ObjectUnsafe::IncRefObjectHandle(from); + return ret; + } + + protected: + // backs up the shape of the NDArray + Optional shape_data_; + + static void DLManagedTensorDeleter(DLManagedTensor* tensor) { + NDArrayObj* obj = static_cast(tensor->manager_ctx); + details::ObjectUnsafe::DecRefObjectHandle(obj); + delete tensor; + } + + static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) { + NDArrayObj* obj = static_cast(tensor->manager_ctx); + details::ObjectUnsafe::DecRefObjectHandle(obj); + delete tensor; + } + + friend class NDArray; +}; + +namespace details { +/*! + *\brief Helper class to create an NDArrayObj from an NDAllocator + * + * The underlying allocator needs to be implemented by user. + */ +template +class NDArrayObjFromNDAlloc : public NDArrayObj { + public: + template + NDArrayObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, + ExtraArgs&&... extra_args) + : alloc_(alloc) { + this->device = device; + this->ndim = static_cast(shape.size()); + this->dtype = dtype; + this->shape = const_cast(shape.data()); + this->strides = nullptr; + this->byte_offset = 0; + this->shape_data_ = std::move(shape); + alloc_.AllocData(static_cast(this), std::forward(extra_args)...); + } + + ~NDArrayObjFromNDAlloc() { alloc_.FreeData(static_cast(this)); } + + private: + TNDAlloc alloc_; +}; + +/*! \brief helper class to import from DLPack legacy DLManagedTensor */ +template +class NDArrayObjFromDLPack : public NDArrayObj { + public: + explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { + *static_cast(this) = tensor_->dl_tensor; + // set strides to nullptr if the tensor is contiguous. + if (IsContiguous(tensor->dl_tensor)) { + this->strides = nullptr; + } + } + + ~NDArrayObjFromDLPack() { + // run DLPack deleter if needed. + if (tensor_->deleter != nullptr) { + (*tensor_->deleter)(tensor_); + } + } + + private: + TDLPackManagedTensor* tensor_; +}; +} // namespace details + +/*! + * \brief Managed NDArray. + * The array is backed by reference counted blocks. + * + * \note This class can be subclassed to implement downstream customized + * NDArray types that are backed by the same NDArrayObj storage type. + */ +class NDArray : public ObjectRef { + public: + /*! + * \brief Get the shape of the NDArray. + * \return The shape of the NDArray. + */ + tvm::ffi::Shape shape() const { + NDArrayObj* obj = get_mutable(); + if (!obj->shape_data_.has_value()) { + obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim); + } + return *(obj->shape_data_); + } + /*! + * \brief Get the data type of the NDArray. + * \return The data type of the NDArray. + */ + DLDataType dtype() const { return (*this)->dtype; } + /*! + * \brief Check if the NDArray is contiguous. + * \return True if the NDArray is contiguous, false otherwise. + */ + bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } + /*! + * \brief Create a NDArray from a NDAllocator. + * \param alloc The NDAllocator. + * \param shape The shape of the NDArray. + * \param dtype The data type of the NDArray. + * \param device The device of the NDArray. + * \return The created NDArray. + * \tparam TNDAlloc The type of the NDAllocator, impelments Alloc and Free. + * \tparam ExtraArgs Extra arguments to be passed to Alloc. + */ + template + static NDArray FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, + ExtraArgs&&... extra_args) { + return NDArray(make_object>( + alloc, shape, dtype, device, std::forward(extra_args)...)); + } + + /*! + * \brief Create a NDArray from a DLPack managed tensor, pre v1.0 API. + * \param tensor The input DLPack managed tensor. + * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \note This function will not run any checks on flags. + * \return The created NDArray. + */ + static NDArray FromDLPack(DLManagedTensor* tensor, size_t require_alignment = 0, + bool require_contiguous = false) { + if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment + << " bytes."; + } + if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; + } + return NDArray(make_object>(tensor)); + } + + /*! + * \brief Create a NDArray from a DLPack managed tensor, post v1.0 API. + * \param tensor The input DLPack managed tensor. + * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \return The created NDArray. + */ + static NDArray FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t require_alignment = 0, + bool require_contiguous = false) { + if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment + << " bytes."; + } + if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; + } + if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) { + TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported"; + } + return NDArray(make_object>(tensor)); + } + + /*! + * \brief Convert the NDArray to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensor* ToDLPack() const { return get_mutable()->ToDLPack(); } + + /*! + * \brief Convert the NDArray to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(NDArray, ObjectRef, NDArrayObj); + + protected: + /*! + * \brief Get mutable internal container pointer. + * \return a mutable container pointer. + */ + NDArrayObj* get_mutable() const { return const_cast(get()); } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CONTAINER_NDARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h new file mode 100644 index 000000000000..7c4839d226f0 --- /dev/null +++ b/ffi/include/tvm/ffi/container/shape.h @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/shape.h + * \brief Container to store shape of an NDArray. + */ +#ifndef TVM_FFI_CONTAINER_SHAPE_H_ +#define TVM_FFI_CONTAINER_SHAPE_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! \brief An object representing a shape tuple. */ +class ShapeObj : public Object, public TVMFFIShapeCell { + public: + using index_type = int64_t; + + /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ + int64_t Product() const { + int64_t product = 1; + for (size_t i = 0; i < this->size; ++i) { + product *= this->data[i]; + } + return product; + } + + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIShape; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ShapeObj, Object); +}; + +namespace details { + +class ShapeObjStdImpl : public ShapeObj { + public: + explicit ShapeObjStdImpl(std::vector other) : data_{other} { + this->data = data_.data(); + this->size = static_cast(data_.size()); + } + + private: + std::vector data_; +}; + +TVM_FFI_INLINE ObjectPtr MakeEmptyShape(size_t length, int64_t** mutable_data) { + ObjectPtr p = make_inplace_array_object(length); + static_assert(alignof(ShapeObj) % alignof(int64_t) == 0); + static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0); + int64_t* data = reinterpret_cast(reinterpret_cast(p.get()) + sizeof(ShapeObj)); + if (mutable_data) { + *mutable_data = data; + } + p->data = data; + p->size = length; + return p; +} + +// inplace shape allocation +template +TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end) { + size_t length = std::distance(begin, end); + int64_t* mutable_data; + ObjectPtr p = MakeEmptyShape(length, &mutable_data); + std::copy(begin, end, mutable_data); + return p; +} + +} // namespace details + +/*! + * \brief Reference to shape object. + */ +class Shape : public ObjectRef { + public: + /*! \brief The type of shape index element. */ + using index_type = ShapeObj::index_type; + + /*! \brief Default constructor */ + Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {} + + /*! + * \brief Constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {} + + /** + * \brief Constructor from Array + * \param shape The Array + * + * \note This constructor will copy the data content. + */ + Shape(Array shape) // NOLINT(*) + : Shape(shape.begin(), shape.end()) {} + + /*! + * \brief constructor from initializer list + * \param shape The initializer list + */ + Shape(std::initializer_list shape) : Shape(shape.begin(), shape.end()) {} + + /*! + * \brief constructor from int64_t [N] + * + * \param other a int64_t array. + */ + Shape(std::vector other) // NOLINT(*) + : ObjectRef(make_object(std::move(other))) {} + + /*! + * \brief Return the data pointer + * + * \return const index_type* data pointer + */ + const int64_t* data() const { return get()->data; } + + /*! + * \brief Return the size of the shape tuple + * + * \return size_t shape tuple size + */ + size_t size() const { return get()->size; } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element. + */ + int64_t operator[](size_t idx) const { + if (idx >= this->size()) { + TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); + } + return this->data()[idx]; + } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element. + */ + int64_t at(size_t idx) const { return this->operator[](idx); } + + /*! \return Whether shape tuple is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the shape tuple */ + int64_t front() const { return this->at(0); } + + /*! \return The last element of the shape tuple */ + int64_t back() const { return this->at(this->size() - 1); } + + /*! \return begin iterator */ + const int64_t* begin() const { return get()->data; } + + /*! \return end iterator */ + const int64_t* end() const { return (get()->data + size()); } + + /*! \return The product of the shape tuple */ + int64_t Product() const { return get()->Product(); } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj); +}; + +inline std::ostream& operator<<(std::ostream& os, const Shape& shape) { + os << '['; + for (size_t i = 0; i < shape.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << shape[i]; + } + os << ']'; + return os; +} + +// Shape +template <> +inline constexpr bool use_default_type_traits_v = false; + +// Allow auto conversion from Array to Shape, but not from Shape to Array +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape; + static TVM_FFI_INLINE Shape ConvertFallbackValue(Array src) { return Shape(src); } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CONTAINER_SHAPE_H_ diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h new file mode 100644 index 000000000000..63c36467fe3d --- /dev/null +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/tuple.h + * \brief Typed tuple like std::tuple backed by ArrayObj container. + */ +#ifndef TVM_FFI_CONTAINER_TUPLE_H_ +#define TVM_FFI_CONTAINER_TUPLE_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Typed tuple like std::tuple backed by ArrayObj container. + * + * Tuple implements in-place copy-on-write semantics. + * + * \tparam Types The types of the tuple elements + */ +template +class Tuple : public ObjectRef { + public: + static_assert(details::all_storage_enabled_v, + "All types used in Tuple<...> must be compatible with Any"); + + Tuple() : ObjectRef(MakeDefaultTupleNode()) {} + Tuple(const Tuple& other) : ObjectRef(other) {} + Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} + template && ...), int>> + Tuple(const Tuple& other) : ObjectRef(other) {} + template && ...), int>> + Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} + template + explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) { + static_assert(sizeof...(Types) == sizeof...(UTypes), "Tuple size mismatch"); + } + + TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { + data_ = std::move(other.data_); + return *this; + } + + template && ...)>> + TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { + data_ = other.data_; + return *this; + } + + template && ...)>> + TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { + data_ = std::move(other.data_); + return *this; + } + + explicit Tuple(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief Get I-th element of the tuple + * + * \tparam I The index of the element to get + * \return The I-th element of the tuple + * \note We use stl style since get usually is like a getter. + */ + template + auto get() const { + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using ReturnType = std::tuple_element_t>; + const Any* ptr = GetArrayObj()->begin() + I; + return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*ptr); + } + + /*! + * \brief Set I-th element of the tuple + * + * \param item The item to set + * \tparam I The index of the element to set + * \tparam U The type of the item + * + * \note This function will perform copy on write if underlying + * container is not uniquely owned. + * We use CamelCase since Set can cause copy on write + * and is more complicated than simple field setter. + */ + template + void Set(U&& item) { + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using T = std::tuple_element_t>; + this->CopyIfNotUnique(); + Any* ptr = GetArrayObj()->MutableBegin() + I; + *ptr = T(std::forward(item)); + } + + /*! \brief specify container node */ + using ContainerType = ArrayObj; + + private: + static ObjectPtr MakeDefaultTupleNode() { + ObjectPtr p = make_inplace_array_object(sizeof...(Types)); + p->capacity_ = sizeof...(Types); + // immeidate set size to 0, to ensure exception safety + p->size_ = 0; + Any* itr = p->MutableBegin(); + // increase size after each new to ensure exception safety + ((new (itr++) Any(Types()), p->size_++), ...); + return p; + } + + template + static ObjectPtr MakeTupleNode(UTypes&&... args) { + ObjectPtr p = make_inplace_array_object(sizeof...(Types)); + p->capacity_ = sizeof...(Types); + // immeidate set size to 0, to ensure exception safety + p->size_ = 0; + Any* itr = p->MutableBegin(); + // increase size after each new to ensure exception safety + ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); + return p; + } + + /*! \brief Copy on write */ + void CopyIfNotUnique() { + if (!data_.unique()) { + ObjectPtr p = make_inplace_array_object(sizeof...(Types)); + p->capacity_ = sizeof...(Types); + // immeidate set size to 0, to ensure exception safety + p->size_ = 0; + Any* itr = p->MutableBegin(); + const Any* read = GetArrayObj()->begin(); + // increase size after each new to ensure exception safety + for (size_t i = 0; i < sizeof...(Types); ++i) { + new (itr++) Any(*read++); + p->size_++; + } + data_ = std::move(p); + } + } + + /*! \return The underlying ArrayObj */ + ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } + + template + friend class Tuple; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public ObjectRefTypeTraitsBase> { + using ObjectRefTypeTraitsBase>::CopyFromAnyStorageAfterCheck; + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + const ArrayObj* n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) { + return "Array[size=" + std::to_string(n->size()) + "]"; + } + return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); + } + + template + static TVM_FFI_INLINE std::string GetMismatchTypeInfoHelper(const Any* arr) { + if constexpr (!std::is_same_v) { + const Any& any_v = arr[I]; + if (!details::AnyUnsafe::CheckAnyStorage(any_v) && !(any_v.as().has_value())) { + // now report the accurate mismatch information + return "Array[index " + std::to_string(I) + ": " + + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; + } + } + if constexpr (sizeof...(Rest) > 0) { + return GetMismatchTypeInfoHelper(arr); + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) return false; + const ArrayObj* n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) return false; + const TVMFFIAny* ffi_any_arr = reinterpret_cast(n->begin()); + return CheckAnyStorageHelper<0, Types...>(ffi_any_arr); + } + + template + static TVM_FFI_INLINE bool CheckAnyStorageHelper(const TVMFFIAny* src_arr) { + if constexpr (!std::is_same_v) { + if (!TypeTraits::CheckAnyStorage(src_arr + I)) { + return false; + } + } + if constexpr (sizeof...(Rest) > 0) { + return CheckAnyStorageHelper(src_arr); + } + return true; + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView( + const TVMFFIAny* src // + ) { + if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; + const ArrayObj* n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) return std::nullopt; + // fast path, storage is already in the right type + if (CheckAnyStorage(src)) { + return CopyFromAnyStorageAfterCheck(src); + } + // slow path, try to convert to each type to match the tuple storage need. + Array arr = TypeTraits>::CopyFromAnyStorageAfterCheck(src); + Any* ptr = arr.CopyOnWrite()->MutableBegin(); + if (TryConvertElements<0, Types...>(ptr)) { + return Tuple(details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); + } + return std::nullopt; + } + + template + static TVM_FFI_INLINE bool TryConvertElements(Any* arr) { + if constexpr (!std::is_same_v) { + if (auto opt_convert = arr[I].as()) { + arr[I] = *std::move(opt_convert); + } else { + return false; + } + } + if constexpr (sizeof...(Rest) > 0) { + return TryConvertElements(std::move(arr)); + } + return true; + } + + static TVM_FFI_INLINE std::string TypeStr() { + return details::ContainerTypeStr("Tuple"); + } +}; + +namespace details { +template +inline constexpr bool type_contains_v, Tuple> = (type_contains_v && ...); +} // namespace details + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h new file mode 100644 index 000000000000..c8b58ba49e39 --- /dev/null +++ b/ffi/include/tvm/ffi/container/variant.h @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/variant.h + * \brief Runtime variant container types. + */ +#ifndef TVM_FFI_CONTAINER_VARIANT_H_ +#define TVM_FFI_CONTAINER_VARIANT_H_ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief A typed variant container. + * + * A Variant is backed by Any container, with strong checks during construction. + */ +template +class Variant { + public: + static_assert(details::all_storage_enabled_v, + "All types used in Variant<...> must be compatible with Any"); + /* + * \brief Helper utility to check if the type can be contained in the variant + */ + template + static constexpr bool variant_contains_v = (details::type_contains_v || ...); + /* \brief Helper utility for SFINAE if the type is part of the variant */ + template + using enable_if_variant_contains_t = std::enable_if_t>; + + Variant(const Variant& other) : data_(other.data_) {} + Variant(Variant&& other) : data_(std::move(other.data_)) {} + + TVM_FFI_INLINE Variant& operator=(const Variant& other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Variant& operator=(Variant&& other) { + data_ = std::move(other.data_); + return *this; + } + + template > + Variant(T other) : data_(std::move(other)) {} // NOLINT(*) + + template > + TVM_FFI_INLINE Variant& operator=(T other) { + data_ = std::move(other); + return *this; + } + + template > + TVM_FFI_INLINE std::optional as() const { + return data_.as(); + } + + /* + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T* as() const { + return data_.as().value_or(nullptr); + } + + template > + TVM_FFI_INLINE T get() const& { + return data_.template cast(); + } + + template > + TVM_FFI_INLINE T get() && { + return std::move(data_).template cast(); + } + + TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); } + + private: + friend struct TypeTraits>; + friend struct ObjectPtrHash; + friend struct ObjectPtrEqual; + // constructor from any + explicit Variant(Any data) : data_(std::move(data)) {} + // internal data is backed by Any + Any data_; + /*! + * \brief Get the object pointer from the variant + * \note This function is only available if all types used in Variant<...> are derived from + * ObjectRef + */ + TVM_FFI_INLINE Object* GetObjectPtrForHashEqual() const { + constexpr bool all_object_v = (std::is_base_of_v && ...); + static_assert(all_object_v, + "All types used in Variant<...> must be derived from ObjectRef " + "to enable ObjectPtrHash/ObjectPtrEqual"); + return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_); + } +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + static TVM_FFI_INLINE void CopyToAnyView(const Variant& src, TVMFFIAny* result) { + *result = AnyView(src.data_).CopyToTVMFFIAny(); + } + + static TVM_FFI_INLINE void MoveToAny(Variant src, TVMFFIAny* result) { + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_)); + } + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return (TypeTraits::CheckAnyStorage(src) || ...); + } + + static TVM_FFI_INLINE Variant CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); + } + + static TVM_FFI_INLINE Variant MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src))); + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + // fast path, storage is already in the right type + if (CheckAnyStorage(src)) { + return CopyFromAnyStorageAfterCheck(src); + } + // More expensive path, try to convert to each type, in order of declaration + return TryVariantTypes(src); + } + + template + static TVM_FFI_INLINE std::optional> TryVariantTypes(const TVMFFIAny* src) { + if (auto opt_convert = TypeTraits::TryConvertFromAnyView(src)) { + return Variant(*std::move(opt_convert)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryVariantTypes(src); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return details::ContainerTypeStr("Variant"); } +}; + +template +TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant& a) const { + return std::hash()(a.GetObjectPtrForHashEqual()); +} + +template +TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant& a, + const Variant& b) const { + return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); +} + +namespace details { +template +inline constexpr bool type_contains_v, T> = (type_contains_v || ...); +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h new file mode 100644 index 000000000000..99eb227ee1af --- /dev/null +++ b/ffi/include/tvm/ffi/dtype.h @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/dtype.h + * \brief Data type handling. + */ +#ifndef TVM_FFI_DTYPE_H_ +#define TVM_FFI_DTYPE_H_ + +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { +/*! + * \brief Extension code beyond the DLDataType. + * + * This class is always consistent with the DLPack. + * + * TOTO(tvm-team): update to latest DLPack types. + */ +enum DLExtDataTypeCode { kDLExtCustomBegin = 129 }; + +namespace details { + +/* + * \brief Convert a DLDataTypeCode to a string. + * \param os The output stream. + * \param type_code The DLDataTypeCode to convert. + */ +inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*) + switch (static_cast(type_code)) { + case kDLInt: { + return "int"; + } + case kDLUInt: { + return "uint"; + } + case kDLFloat: { + return "float"; + } + case kDLOpaqueHandle: { + return "handle"; + } + case kDLBfloat: { + return "bfloat"; + } + case kDLFloat8_e3m4: { + return "float8_e3m4"; + } + case kDLFloat8_e4m3: { + return "float8_e4m3"; + } + case kDLFloat8_e4m3b11fnuz: { + return "float8_e4m3b11fnuz"; + } + case kDLFloat8_e4m3fn: { + return "float8_e4m3fn"; + } + case kDLFloat8_e4m3fnuz: { + return "float8_e4m3fnuz"; + } + case kDLFloat8_e5m2: { + return "float8_e5m2"; + } + case kDLFloat8_e5m2fnuz: { + return "float8_e5m2fnuz"; + } + case kDLFloat8_e8m0fnu: { + return "float8_e8m0fnu"; + } + case kDLFloat6_e2m3fn: { + return "float6_e2m3fn"; + } + case kDLFloat6_e3m2fn: { + return "float6_e3m2fn"; + } + case kDLFloat4_e2m1fn: { + return "float4_e2m1fn"; + } + default: { + if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { + return "custom"; + } else { + TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" + << static_cast(type_code); + } + TVM_FFI_UNREACHABLE(); + } + } +} +} // namespace details + +inline DLDataType StringToDLDataType(const String& str) { + DLDataType out; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(str.get(), &out)); + return out; +} + +inline String DLDataTypeToString(DLDataType dtype) { + TVMFFIObjectHandle out; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(dtype, &out)); + return String(details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(out))); +} + +// DLDataType +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; + + static TVM_FFI_INLINE void CopyToAnyView(const DLDataType& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDataType; + result->v_dtype = src; + } + + static TVM_FFI_INLINE void MoveToAny(DLDataType src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDataType; + result->v_dtype = src; + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIDataType; + } + + static TVM_FFI_INLINE DLDataType CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return src->v_dtype; + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIDataType) { + return src->v_dtype; + } + // enable string to dtype auto conversion + if (auto opt_str = TypeTraits::TryConvertFromAnyView(src)) { + return StringToDLDataType(*opt_str); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } +}; +} // namespace ffi +} // namespace tvm + +// define DLDataType comparison and printing in root namespace +inline std::ostream& operator<<(std::ostream& os, DLDataType dtype) { // NOLINT(*) + return os << tvm::ffi::DLDataTypeToString(dtype); +} + +inline bool operator==(const DLDataType& lhs, const DLDataType& rhs) { + return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; +} + +inline bool operator!=(const DLDataType& lhs, const DLDataType& rhs) { return !(lhs == rhs); } +#endif // TVM_FFI_DTYPE_H_ diff --git a/ffi/include/tvm/ffi/endian.h b/ffi/include/tvm/ffi/endian.h new file mode 100644 index 000000000000..4a73b82e6c30 --- /dev/null +++ b/ffi/include/tvm/ffi/endian.h @@ -0,0 +1,89 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/endian.h + * \brief Endian detection and handling + */ +#ifndef TVM_FFI_ENDIAN_H_ +#define TVM_FFI_ENDIAN_H_ + +#include +#include + +#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN +#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1 +#endif + +#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN +// If compiled with CMake, use CMake's endian detection logic +#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN +#else +#if defined(__APPLE__) || defined(_WIN32) +#define TVM_FFI_LITTLE_ENDIAN 1 +#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__) +#include +#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) +#elif defined(__FreeBSD__) || defined(__OpenBSD__) +#include +#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN) +#elif defined(__QNX__) +#include +#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN) +#elif defined(__EMSCRIPTEN__) || defined(__hexagon__) +#define TVM_FFI_LITTLE_ENDIAN 1 +#elif defined(__sun) || defined(sun) +#include +#if defined(_LITTLE_ENDIAN) +#define TVM_FFI_LITTLE_ENDIAN 1 +#else +#define TVM_FFI_LITTLE_ENDIAN 0 +#endif +#else +#error "Unable to determine endianness of your machine; use CMake to compile" +#endif +#endif + +/*! \brief whether serialize using little endian */ +#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN) + +namespace tvm { +namespace ffi { +/*! + * \brief A generic inplace byte swapping function. + * \param data The data pointer. + * \param elem_bytes The number of bytes of the data elements + * \param num_elems Number of elements in the data. + * \note Always try pass in constant elem_bytes to enable + * compiler optimization + */ +inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; + for (size_t j = 0; j < elem_bytes / 2; ++j) { + uint8_t v = bptr[elem_bytes - 1 - j]; + bptr[elem_bytes - 1 - j] = bptr[j]; + bptr[j] = v; + } + } +} +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ENDIAN_H_ diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h new file mode 100644 index 000000000000..239a0e500b73 --- /dev/null +++ b/ffi/include/tvm/ffi/error.h @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/error.h + * \brief Error handling component. + */ +#ifndef TVM_FFI_ERROR_H_ +#define TVM_FFI_ERROR_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +/*! + * \brief Macro defines whether we enable libbacktrace + */ +#ifndef TVM_FFI_USE_LIBBACKTRACE +#define TVM_FFI_USE_LIBBACKTRACE 1 +#endif + +/*! + * \brief Macro defines whether to install signal handler + * and print backtrace during segfault + */ +#ifndef TVM_FFI_BACKTRACE_ON_SEGFAULT +#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 +#endif + +namespace tvm { +namespace ffi { + +/*! + * \brief Error already set in frontend env. + * + * This error can be thrown by EnvCheckSignals to indicate + * that there is an error set in the frontend environment(e.g. + * python interpreter). The TVM FFI should catch this error + * and return a proper code tell the frontend caller about + * this fact. + * + * \code + * + * void ExampleLongRunningFunction() { + * if (TVMFFIEnvCheckSignals() != 0) { + * throw ::tvm::ffi::EnvErrorAlreadySet(); + * } + * // do work here + * } + * + * \endcode + */ +struct EnvErrorAlreadySet : public std::exception {}; + +/*! + * \brief Error object class. + */ +class ErrorObj : public Object, public TVMFFIErrorCell { + public: + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; + static constexpr const char* _type_key = "object.Error"; + + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object); +}; + +namespace details { +class ErrorObjFromStd : public ErrorObj { + public: + ErrorObjFromStd(std::string kind, std::string message, std::string traceback) + : kind_data_(kind), message_data_(message), traceback_data_(traceback) { + this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()}; + this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()}; + this->traceback = TVMFFIByteArray{traceback_data_.data(), traceback_data_.length()}; + this->update_traceback = UpdateTraceback; + } + + private: + /*! + * \brief Update the traceback of the error object. + * \param traceback The traceback to update. + */ + static void UpdateTraceback(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback_str) { + ErrorObjFromStd* obj = static_cast(self); + obj->traceback_data_ = std::string(traceback_str->data, traceback_str->size); + obj->traceback = TVMFFIByteArray{obj->traceback_data_.data(), obj->traceback_data_.length()}; + } + + std::string kind_data_; + std::string message_data_; + std::string traceback_data_; +}; +} // namespace details + +/*! + * \brief Managed reference to ErrorObj + * \sa Error Object + */ +class Error : public ObjectRef, public std::exception { + public: + Error(std::string kind, std::string message, std::string traceback) { + data_ = make_object(kind, message, traceback); + } + + Error(std::string kind, std::string message, const TVMFFIByteArray* traceback) + : Error(kind, message, std::string(traceback->data, traceback->size)) {} + + std::string kind() const { + ErrorObj* obj = static_cast(data_.get()); + return std::string(obj->kind.data, obj->kind.size); + } + + std::string message() const { + ErrorObj* obj = static_cast(data_.get()); + return std::string(obj->message.data, obj->message.size); + } + + std::string traceback() const { + ErrorObj* obj = static_cast(data_.get()); + return std::string(obj->traceback.data, obj->traceback.size); + } + + void UpdateTraceback(const TVMFFIByteArray* traceback_str) { + ErrorObj* obj = static_cast(data_.get()); + obj->update_traceback(obj, traceback_str); + } + + const char* what() const noexcept(true) override { + thread_local std::string what_data; + ErrorObj* obj = static_cast(data_.get()); + what_data = (std::string("Traceback (most recent call last):\n") + + std::string(obj->traceback.data, obj->traceback.size) + + std::string(obj->kind.data, obj->kind.size) + std::string(": ") + + std::string(obj->message.data, obj->message.size) + '\n'); + return what_data.c_str(); + } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj); +}; + +namespace details { + +class ErrorBuilder { + public: + explicit ErrorBuilder(std::string kind, std::string traceback, bool log_before_throw) + : kind_(kind), traceback_(traceback), log_before_throw_(log_before_throw) {} + + explicit ErrorBuilder(std::string kind, const TVMFFIByteArray* traceback, bool log_before_throw) + : ErrorBuilder(kind, std::string(traceback->data, traceback->size), log_before_throw) {} + +// MSVC disable warning in error builder as it is exepected +#ifdef _MSC_VER +#pragma disagnostic push +#pragma warning(disable : 4722) +#endif + // avoid inline to reduce binary size, error throw path do not need to be fast + [[noreturn]] ~ErrorBuilder() noexcept(false) { + ::tvm::ffi::Error error(std::move(kind_), stream_.str(), std::move(traceback_)); + if (log_before_throw_) { + std::cerr << error.what(); + } + throw error; + } +#ifdef _MSC_VER +#pragma disagnostic pop +#endif + + std::ostringstream& stream() { return stream_; } + + protected: + std::string kind_; + std::ostringstream stream_; + std::string traceback_; + bool log_before_throw_; +}; + +// define traceback here as call into traceback function +#define TVM_FFI_TRACEBACK_HERE TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG) +} // namespace details + +/*! + * \brief Helper macro to throw an error with traceback and message + * + * \code + * + * void ThrowError() { + * TVM_FFI_THROW(RuntimeError) << "error message"; + * } + * + * \endcode + */ +#define TVM_FFI_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, false).stream() + +/*! + * \brief Explicitly log error in stderr and then throw the error. + * + * \note This is only necessary on startup functions where we know error + * cannot be caught, and it is better to have a clear log message. + * In most cases, we should use use TVM_FFI_THROW. + */ +#define TVM_FFI_LOG_AND_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, true).stream() + +// Glog style checks with TVM_FFI prefix +// NOTE: we explicitly avoid glog style generic macros (LOG/CHECK) in tvm ffi +// to avoid potential conflict of downstream users who might have their own GLOG style macros +namespace details { + +template +TVM_FFI_INLINE std::unique_ptr LogCheckFormat(const X& x, const Y& y) { + std::ostringstream os; + os << " (" << x << " vs. " << y << ") "; // CHECK_XX(x, y) requires x and y can be serialized to + // string. Use CHECK(x OP y) otherwise. + return std::make_unique(os.str()); +} + +#define TVM_FFI_CHECK_FUNC(name, op) \ + template \ + TVM_FFI_INLINE std::unique_ptr LogCheck##name(const X& x, const Y& y) { \ + if (x op y) return nullptr; \ + return LogCheckFormat(x, y); \ + } \ + TVM_FFI_INLINE std::unique_ptr LogCheck##name(int x, int y) { \ + return LogCheck##name(x, y); \ + } + +// Inline _Pragma in macros does not work reliably on old version of MSVC and +// GCC. We wrap all comparisons in a function so that we can use #pragma to +// silence bad comparison warnings. +#if defined(__GNUC__) || defined(__clang__) // GCC and Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +#elif defined(_MSC_VER) // MSVC +#pragma warning(push) +#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch +#endif + +TVM_FFI_CHECK_FUNC(_LT, <) +TVM_FFI_CHECK_FUNC(_GT, >) +TVM_FFI_CHECK_FUNC(_LE, <=) +TVM_FFI_CHECK_FUNC(_GE, >=) +TVM_FFI_CHECK_FUNC(_EQ, ==) +TVM_FFI_CHECK_FUNC(_NE, !=) + +#if defined(__GNUC__) || defined(__clang__) // GCC and Clang +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) // MSVC +#pragma warning(pop) +#endif +} // namespace details + +#define TVM_FFI_ICHECK_BINARY_OP(name, op, x, y) \ + if (auto __tvm__log__err = ::tvm::ffi::details::LogCheck##name(x, y)) \ + TVM_FFI_THROW(InternalError) << "Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": " + +#define TVM_FFI_ICHECK(x) \ + if (!(x)) TVM_FFI_THROW(InternalError) << "Check failed: (" #x << ") is false: " + +#define TVM_FFI_ICHECK_LT(x, y) TVM_FFI_ICHECK_BINARY_OP(_LT, <, x, y) +#define TVM_FFI_ICHECK_GT(x, y) TVM_FFI_ICHECK_BINARY_OP(_GT, >, x, y) +#define TVM_FFI_ICHECK_LE(x, y) TVM_FFI_ICHECK_BINARY_OP(_LE, <=, x, y) +#define TVM_FFI_ICHECK_GE(x, y) TVM_FFI_ICHECK_BINARY_OP(_GE, >=, x, y) +#define TVM_FFI_ICHECK_EQ(x, y) TVM_FFI_ICHECK_BINARY_OP(_EQ, ==, x, y) +#define TVM_FFI_ICHECK_NE(x, y) TVM_FFI_ICHECK_BINARY_OP(_NE, !=, x, y) +#define TVM_FFI_ICHECK_NOTNULL(x) \ + ((x) == nullptr ? TVM_FFI_THROW(InternalError) << "Check not null: " #x << ' ', \ + (x) : (x)) // NOLINT(*) +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ERROR_H_ diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h new file mode 100644 index 000000000000..708495b58199 --- /dev/null +++ b/ffi/include/tvm/ffi/function.h @@ -0,0 +1,914 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/function.h + * \brief A managed function in the TVM FFI. + */ +#ifndef TVM_FFI_FUNCTION_H_ +#define TVM_FFI_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/** + * Helper macro to construct a safe call + * + * \brief Marks the begining of the safe call that catches exception explicitly + * + */ +#define TVM_FFI_SAFE_CALL_BEGIN() \ + try { \ + (void)0 + +/*! + * \brief Marks the end of safe call. + */ +#define TVM_FFI_SAFE_CALL_END() \ + return 0; \ + } \ + catch (const ::tvm::ffi::Error& err) { \ + ::tvm::ffi::details::SetSafeCallRaised(err); \ + return -1; \ + } \ + catch (const ::tvm::ffi::EnvErrorAlreadySet&) { \ + return -2; \ + } \ + catch (const std::exception& ex) { \ + ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \ + return -1; \ + } \ + TVM_FFI_UNREACHABLE() + +#define TVM_FFI_CHECK_SAFE_CALL(func) \ + { \ + int ret_code = (func); \ + if (ret_code != 0) { \ + if (ret_code == -2) { \ + throw ::tvm::ffi::EnvErrorAlreadySet(); \ + } \ + throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \ + } \ + } + +/*! + * \brief Object container class that backs ffi::Function + * \note Do not use this function directly, use ffi::Function + */ +class FunctionObj : public Object, public TVMFFIFunctionCell { + public: + typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*); + using TVMFFIFunctionCell::safe_call; + /*! \brief A C++ style call implementation, with exception propagation in c++ style. */ + FCall call; + + TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { + this->call(this, args, num_args, result); + } + + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; + static constexpr const char* _type_key = "object.Function"; + + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(FunctionObj, Object); + + protected: + /*! \brief Make default constructor protected. */ + FunctionObj() {} + + // Implementing safe call style + static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); + FunctionObj* self = static_cast(func); + self->call(self, reinterpret_cast(args), num_args, + reinterpret_cast(result)); + TVM_FFI_SAFE_CALL_END(); + } + + friend class Function; +}; + +namespace details { +/*! + * \brief Derived object class for constructing FunctionObj backed by a TCallable + * + * This is a helper class that + */ +template +class FunctionObjImpl : public FunctionObj { + public: + using TStorage = typename std::remove_cv::type>::type; + /*! \brief The type of derived object class */ + using TSelf = FunctionObjImpl; + /*! + * \brief Derived object class for constructing PackedFuncObj. + * \param callable The type-erased callable object. + */ + explicit FunctionObjImpl(TCallable callable) : callable_(callable) { + this->safe_call = SafeCall; + this->call = Call; + } + + private: + // implementation of call + static void Call(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) { + (static_cast(func))->callable_(args, num_args, result); + } + + /*! \brief Type-erased filed for storing callable object*/ + mutable TStorage callable_; +}; + +/*! + * \brief Base class to provide a common implementation to redirect call to safecall + * \tparam Derived The derived class in CRTP-idiom + */ +template +struct RedirectCallToSafeCall { + static void Call(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* rv) { + Derived* self = static_cast(const_cast(func)); + TVM_FFI_CHECK_SAFE_CALL(self->RedirectSafeCall(reinterpret_cast(args), + num_args, reinterpret_cast(rv))); + } + + static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) { + Derived* self = reinterpret_cast(func); + return self->RedirectSafeCall(args, num_args, rv); + } +}; + +/*! + * \brief FunctionObj specialization that leverages C-style callback definitions. + */ +class ExternCFunctionObjImpl : public FunctionObj, + public RedirectCallToSafeCall { + public: + using RedirectCallToSafeCall::SafeCall; + + ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) + : self_(self), safe_call_(safe_call), deleter_(deleter) { + this->call = RedirectCallToSafeCall::Call; + this->safe_call = RedirectCallToSafeCall::SafeCall; + } + + ~ExternCFunctionObjImpl() { deleter_(self_); } + + TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t num_args, + TVMFFIAny* rv) const { + return safe_call_(self_, args, num_args, rv); + } + + private: + void* self_; + TVMFFISafeCallType safe_call_; + void (*deleter_)(void* self); +}; + +/*! + * \brief FunctionObj specialization that wraps an external function. + */ +class ImportedFunctionObjImpl : public FunctionObj, + public RedirectCallToSafeCall { + public: + using RedirectCallToSafeCall::SafeCall; + + explicit ImportedFunctionObjImpl(ObjectPtr data) : data_(data) { + this->call = RedirectCallToSafeCall::Call; + this->safe_call = RedirectCallToSafeCall::SafeCall; + } + + TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t num_args, + TVMFFIAny* rv) const { + FunctionObj* func = const_cast(static_cast(data_.get())); + return func->safe_call(func, args, num_args, rv); + } + + private: + ObjectPtr data_; +}; + +// Helper class to set packed arguments +class PackedArgsSetter { + public: + explicit PackedArgsSetter(AnyView* args) : args_(args) {} + + // NOTE: setter needs to be very carefully designed + // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) + // that is why we need T&& and std::forward here + template + TVM_FFI_INLINE void operator()(size_t i, T&& value) const { + args_[i].operator=(std::forward(value)); + } + + private: + AnyView* args_; +}; +} // namespace details + +/*! + * \brief Represents arguments packed in AnyView array + * \note This class represent packed arguments to ffi::Function + */ +class PackedArgs { + public: + /*! + * \brief Constructor + * \param data The arguments + * \param size The number of arguments + */ + PackedArgs(const AnyView* data, int32_t size) : data_(data), size_(size) {} + + /*! \return size of the arguments */ + int size() const { return size_; } + + /*! \return The arguments */ + const AnyView* data() const { return data_; } + + /*! + * \brief Slice the arguments + * \param begin The begin index + * \param end The end index + * \return The sliced arguments + */ + PackedArgs Slice(int begin, int end = -1) const { + if (end == -1) { + end = size_; + } + return PackedArgs(data_ + begin, end - begin); + } + + /*! + * \brief Get i-th argument + * \param i the index. + * \return the ith argument. + */ + AnyView operator[](int i) const { return data_[i]; } + + /*! + * \brief Fill the arguments into the AnyView array + * \param data The AnyView array to store the packed arguments + * \param args The arguments to be packed + * \note Caller must ensure all args are alive during lifetime of data. + * A common pitfall is to pass in local variables that are immediately + * destroyed after calling Fill. + */ + template + static void TVM_FFI_INLINE Fill(AnyView* data, Args&&... args) { + details::for_each(details::PackedArgsSetter(data), std::forward(args)...); + } + + private: + /*! \brief The arguments */ + const AnyView* data_; + /*! \brief The number of arguments */ + int32_t size_; +}; + +/*! + * \brief ffi::Function is a type-erased function. + * The arguments are passed by packed format. + */ +class Function : public ObjectRef { + public: + /*! \brief Constructor from null */ + Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param packed_call The packed function signature + * \note legacy purpose, should change to Function::FromPacked for mostfuture use. + */ + template + explicit Function(TCallable packed_call) { + *this = FromPacked(packed_call); + } + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param packed_call The packed function signature + */ + template + static Function FromPacked(TCallable packed_call) { + static_assert( + std::is_convertible_v> || + std::is_convertible_v>, + "tvm::ffi::Function::FromPacked requires input function signature to match packed func " + "format"); + if constexpr (std::is_convertible_v>) { + auto wrapped_call = [packed_call](const AnyView* args, int32_t num_args, + Any* rv) mutable -> void { + PackedArgs args_pack(args, num_args); + packed_call(args_pack, rv); + }; + return FromPackedInternal(wrapped_call); + } else { + return FromPackedInternal(packed_call); + } + } + /*! + * \brief Import a possibly externally defined function to this dll + * \param other Function defined in another dynamic library. + * + * \note This function will redirect the call to safe_call in other. + * It will try to detect if the function is already from the same DLL + * and directly return the original function if so. + * + * \return The imported function. + */ + static Function ImportFromExternDLL(Function other) { + const FunctionObj* other_func = static_cast(other.get()); + // the other function comes from the same dll, no action needed + if (other_func->safe_call == &(FunctionObj::SafeCall) || + other_func->safe_call == &(details::ImportedFunctionObjImpl::SafeCall) || + other_func->safe_call == &(details::ExternCFunctionObjImpl::SafeCall)) { + return other; + } + // the other function coems from a different library + Function func; + func.data_ = make_object(std::move(other.data_)); + return func; + } + /*! + * \brief Create ffi::Function from a C style callbacks. + * \param self Resource handle to the function + * \param safe_call The safe_call definition in C. + * \param deleter The deleter to release the resource of self. + * \return The created function. + */ + static Function FromExternC(void* self, TVMFFISafeCallType safe_call, + void (*deleter)(void* self)) { + // the other function coems from a different library + Function func; + func.data_ = make_object(self, safe_call, deleter); + return func; + } + /*! + * \brief Get global function by name + * \param name The function name + * \return The global function. + * \note This function will return std::nullopt if the function is not found. + */ + static std::optional GetGlobal(std::string_view name) { + TVMFFIObjectHandle handle; + TVMFFIByteArray name_arr{name.data(), name.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); + if (handle != nullptr) { + return Function( + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); + } else { + return std::nullopt; + } + } + + static std::optional GetGlobal(const std::string& name) { + return GetGlobal(std::string_view(name.data(), name.length())); + } + + static std::optional GetGlobal(const String& name) { + return GetGlobal(std::string_view(name.data(), name.length())); + } + + static std::optional GetGlobal(const char* name) { + return GetGlobal(std::string_view(name)); + } + /*! + * \brief Get global function by name and throw an error if it is not found. + * \param name The name of the function + * \return The global function + * \note This function will throw an error if the function is not found. + */ + static Function GetGlobalRequired(std::string_view name) { + std::optional res = GetGlobal(name); + if (!res.has_value()) { + TVM_FFI_THROW(ValueError) << "Function " << name << " not found"; + } + return res.value(); + } + + static Function GetGlobalRequired(const std::string& name) { + return GetGlobalRequired(std::string_view(name.data(), name.length())); + } + + static Function GetGlobalRequired(const String& name) { + return GetGlobalRequired(std::string_view(name.data(), name.length())); + } + + static Function GetGlobalRequired(const char* name) { + return GetGlobalRequired(std::string_view(name)); + } + /*! + * \brief Set global function by name + * \param name The name of the function + * \param func The function + * \param override Whether to override when there is duplication. + */ + static void SetGlobal(std::string_view name, Function func, bool override = false) { + TVMFFIByteArray name_arr{name.data(), name.size()}; + TVM_FFI_CHECK_SAFE_CALL( + TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override)); + } + /*! + * \brief List all global names + * \return A vector of all global names + * \note This function do not depend on Array so core do not have container dep. + */ + static std::vector ListGlobalNames() { + Function fname_functor = + GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast(); + std::vector names; + int len = fname_functor(-1).cast(); + for (int i = 0; i < len; ++i) { + names.push_back(fname_functor(i).cast()); + } + return names; + } + /** + * \brief Remove a global function by name + * \param name The name of the function + */ + static void RemoveGlobal(const String& name) { + static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); + fremove(name); + } + /*! + * \brief Constructing a packed function from a normal function. + * + * \param callable the internal container of packed function. + */ + template + static Function FromUnpacked(TCallable callable) { + using FuncInfo = details::FunctionInfo; + auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { + details::unpack_call( + std::make_index_sequence{}, nullptr, callable, args, num_args, rv); + }; + return FromPackedInternal(call_packed); + } + /*! + * \brief Constructing a packed function from a normal function. + * + * \param callable the internal container of packed function. + * \param name optional name attacked to the function. + */ + template + static Function FromUnpacked(TCallable callable, std::string name) { + using FuncInfo = details::FunctionInfo; + auto call_packed = [callable, name](const AnyView* args, int32_t num_args, + Any* rv) mutable -> void { + details::unpack_call( + std::make_index_sequence{}, &name, callable, args, num_args, rv); + }; + return FromPackedInternal(call_packed); + } + /*! + * \brief Call function by directly passing in unpacked arguments. + * + * \param args Arguments to be passed. + * \tparam Args arguments to be passed. + * + * \code + * // Example code on how to call packed function + * void CallFFIFunction(tvm::ffi::Function f) { + * // call like normal functions by pass in arguments + * // return value is automatically converted back + * int rvalue = f(1, 2.0); + * } + * \endcode + */ + template + TVM_FFI_INLINE Any operator()(Args&&... args) const { + const int kNumArgs = sizeof...(Args); + const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; + AnyView args_pack[kArraySize]; + PackedArgs::Fill(args_pack, std::forward(args)...); + Any result; + static_cast(data_.get())->CallPacked(args_pack, kNumArgs, &result); + return result; + } + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param rv The return value. + */ + TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { + static_cast(data_.get())->CallPacked(args, num_args, result); + } + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param result The return value. + */ + TVM_FFI_INLINE void CallPacked(PackedArgs args, Any* result) const { + static_cast(data_.get())->CallPacked(args.data(), args.size(), result); + } + + /*! \return Whether the packed function is nullptr */ + TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } + /*! \return Whether the packed function is not nullptr */ + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj); + + class Registry; + + private: + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param packed_call The packed function signature + */ + template + static Function FromPackedInternal(TCallable packed_call) { + using ObjType = typename details::FunctionObjImpl; + Function func; + func.data_ = make_object(std::forward(packed_call)); + return func; + } +}; + +/*! + * \brief Please refer to \ref TypedFunctionAnchor "TypedFunction" + */ +template +class TypedFunction; + +/*! + * \anchor TypedFunctionAnchor + * \brief A PackedFunc wrapper to provide typed function signature. + * It is backed by a PackedFunc internally. + * + * TypedFunction enables compile time type checking. + * TypedFunction works with the runtime system: + * - It can be passed as an argument of PackedFunc. + * - It can be assigned to TVMRetValue. + * - It can be directly converted to a type-erased PackedFunc. + * + * Developers should prefer TypedFunction over PackedFunc in C++ code + * as it enables compile time checking. + * We can construct a TypedFunction from a lambda function + * with the same signature. + * + * \code + * // user defined lambda function. + * auto addone = [](int x)->int { + * return x + 1; + * }; + * // We can directly convert + * // lambda function to TypedFunction + * TypedFunction ftyped(addone); + * // invoke the function. + * int y = ftyped(1); + * // Can be directly converted to PackedFunc + * PackedFunc packed = ftype; + * \endcode + * \tparam R The return value of the function. + * \tparam Args The argument signature of the function. + */ +template +class TypedFunction { + public: + /*! \brief short hand for this function type */ + using TSelf = TypedFunction; + /*! \brief default constructor */ + TypedFunction() {} + /*! \brief constructor from null */ + TypedFunction(std::nullptr_t null) {} // NOLINT(*) + /*! + * \brief constructor from a function + * \param packed The function + */ + TypedFunction(Function packed) : packed_(packed) {} // NOLINT(*) + /*! + * \brief construct from a lambda function with the same signature. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function + * TypedFunction ftyped(typed_lambda, "add_one"); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \param name the name of the lambda function. + * \tparam FLambda the type of the lambda function. + */ + template >::value>::type> + TypedFunction(FLambda typed_lambda, std::string name) { // NOLINT(*) + packed_ = Function::FromUnpacked(typed_lambda, name); + } + /*! + * \brief construct from a lambda function with the same signature. + * + * This version does not take a name. It is highly recommend you use the + * version that takes a name for the lambda. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function + * TypedFunction ftyped(typed_lambda); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + */ + template >::value>::type> + TypedFunction(const FLambda& typed_lambda) { // NOLINT(*) + packed_ = Function::FromUnpacked(typed_lambda); + } + /*! + * \brief copy assignment operator from typed lambda + * + * Example usage: + * \code + * // construct from packed function + * TypedFunction ftyped; + * ftyped = [](int x) { return x + 1; } + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + * \returns reference to self. + */ + template >::value>::type> + TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) + packed_ = Function::FromUnpacked(typed_lambda); + return *this; + } + /*! + * \brief copy assignment operator from PackedFunc. + * \param packed The packed function. + * \returns reference to self. + */ + TSelf& operator=(Function packed) { + packed_ = std::move(packed); + return *this; + } + /*! + * \brief Invoke the operator. + * \param args The arguments + * \returns The return value. + */ + TVM_FFI_INLINE R operator()(Args... args) const { + if constexpr (std::is_same_v) { + packed_(std::forward(args)...); + } else { + Any res = packed_(std::forward(args)...); + if constexpr (std::is_same_v) { + return res; + } else { + return std::move(res).cast(); + } + } + } + /*! + * \brief convert to PackedFunc + * \return the internal PackedFunc + */ + operator Function() const { return packed(); } + /*! + * \return reference the internal PackedFunc + */ + const Function& packed() const& { return packed_; } + /*! + * \return r-value reference the internal PackedFunc + */ + constexpr Function&& packed() && { return std::move(packed_); } + /*! \return Whether the packed function is nullptr */ + bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } + /*! \return Whether the packed function is not nullptr */ + bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } + + private: + /*! \brief The internal packed function */ + Function packed_; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; + + static TVM_FFI_INLINE void CopyToAnyView(const TypedFunction& src, TVMFFIAny* result) { + TypeTraits::CopyToAnyView(src.packed(), result); + } + + static TVM_FFI_INLINE void MoveToAny(TypedFunction src, TVMFFIAny* result) { + TypeTraits::MoveToAny(std::move(src.packed()), result); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIFunction; + } + + static TVM_FFI_INLINE TypedFunction CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return TypedFunction(TypeTraits::CopyFromAnyStorageAfterCheck(src)); + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView( + const TVMFFIAny* src) { + std::optional opt = TypeTraits::TryConvertFromAnyView(src); + if (opt.has_value()) { + return TypedFunction(*std::move(opt)); + } else { + return std::nullopt; + } + } + + static TVM_FFI_INLINE std::string TypeStr() { return details::FunctionInfo::Sig(); } +}; + +/*! \brief Registry for global function */ +class Function::Registry { + public: + /*! \brief constructor */ + explicit Registry(const char* name) : name_(name) {} + + /*! + * \brief Set body to be to use the packed convention. + * + * \tparam FLambda The signature of the function. + * \param f The body of the function. + */ + template + Registry& set_body_packed(FLambda f) { + return Register(ffi::Function::FromPacked(f)); + } + /*! + * \brief set the body of the function to the given function. + * Note that this will ignore default arg values and always require all arguments to be + * provided. + * + * \code + * + * int multiply(int x, int y) { + * return x * y; + * } + * + * TVM_FFI_REGISTER_GLOBAL("multiply") + * .set_body_typed(multiply); // will have type int(int, int) + * + * // will have type int(int, int) + * TVM_REGISTER_GLOBAL("sub") + * .set_body_typed([](int a, int b) -> int { return a - b; }); + * + * \endcode + * + * \param f The function to forward to. + * \tparam FLambda The signature of the function. + */ + template + Registry& set_body_typed(FLambda f) { + return Register(Function::FromUnpacked(f, name_)); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Note that this will ignore default arg values and always require all arguments to be + * provided. + * + * \code + * + * // objectRef subclass: + * struct Example : ObjectRef { + * int DoThing(int x); + * } + * TVM_FFI_REGISTER_GLOBAL("Example_DoThing") + * .set_body_method(&Example::DoThing); // will have type int(self, int) + * + * // Object subclass: + * struct Example : Object { + * int DoThing(int x); + * } + * + * TVM_FFI_REGISTER_GLOBAL("Example_DoThing") + * .set_body_method(&Example::DoThing); // will have type int(self, int) + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam T the type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_method(R (T::*f)(Args...)) { + static_assert(std::is_base_of_v || std::is_base_of_v, + "T must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [f](T target, Args... params) -> R { + // call method pointer + return (target.*f)(params...); + }; + return Register(ffi::Function::FromUnpacked(fwrap, name_)); + } + if constexpr (std::is_base_of_v) { + auto fwrap = [f](const T* target, Args... params) -> R { + // call method pointer + return (const_cast(target)->*f)(params...); + }; + return Register(ffi::Function::FromUnpacked(fwrap, name_)); + } + return *this; + } + + template + Registry& set_body_method(R (T::*f)(Args...) const) { + static_assert(std::is_base_of_v || std::is_base_of_v, + "T must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [f](const T target, Args... params) -> R { + // call method pointer + return (target.*f)(params...); + }; + return Register(ffi::Function::FromUnpacked(fwrap, name_)); + } + if constexpr (std::is_base_of_v) { + auto fwrap = [f](const T* target, Args... params) -> R { + // call method pointer + return (target->*f)(params...); + }; + return Register(ffi::Function::FromUnpacked(fwrap, name_)); + } + return *this; + } + + protected: + /*! + * \brief set the body of the function to be f + * \param f The body of the function. + */ + Registry& Register(Function f) { + Function::SetGlobal(name_, f); + return *this; + } + + /*! \brief name of the function */ + const char* name_; +}; + +/*! + * \brief helper function to get type index from key + */ +inline int32_t TypeKeyToIndex(std::string_view type_key) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + return type_index; +} + +#define TVM_FFI_FUNC_REG_VAR_DEF \ + static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::Function::Registry& __##TVMFFIFuncReg + +/*! + * \brief Register a function globally. + * \code + * TVM_FFI_REGISTER_GLOBAL("MyAdd") + * .set_body_typed([](int a, int b) { + * return a + b; + * }); + * \endcode + */ +#define TVM_FFI_REGISTER_GLOBAL(OpName) \ + TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName) +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h new file mode 100644 index 000000000000..f47a253a5872 --- /dev/null +++ b/ffi/include/tvm/ffi/function_details.h @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/function_details.h + * \brief Implements the funciton signature reflection + */ +#ifndef TVM_FFI_FUNCTION_DETAILS_H_ +#define TVM_FFI_FUNCTION_DETAILS_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace details { + +template +struct Arg2Str { + template + static TVM_FFI_INLINE void Apply(std::ostream& os) { + using Arg = std::tuple_element_t; + if constexpr (i != 0) { + os << ", "; + } + os << i << ": " << Type2Str::v(); + } + template + static TVM_FFI_INLINE void Run(std::ostream& os, std::index_sequence) { + using TExpander = int[]; + (void)TExpander{0, (Apply(os), 0)...}; + } +}; + +template +static constexpr bool ArgSupported = + (std::is_same_v>, Any> || + std::is_same_v>, AnyView> || + TypeTraitsNoCR::convert_enabled); + +// NOTE: return type can only support non-reference managed returns +template +static constexpr bool RetSupported = + (std::is_same_v || std::is_void_v || TypeTraits::convert_enabled); + +template +struct FuncFunctorImpl { + using FType = R(Args...); + using ArgType = std::tuple; + using RetType = R; + /*! \brief total number of arguments*/ + static constexpr size_t num_args = sizeof...(Args); + // MSVC is not that friendly to in-template nested bool evaluation +#ifndef _MSC_VER + /*! \brief Whether this function can be converted to ffi::Function via FromUnpacked */ + static constexpr bool unpacked_supported = (ArgSupported && ...) && (RetSupported); +#endif + + static TVM_FFI_INLINE std::string Sig() { + using IdxSeq = std::make_index_sequence; + std::ostringstream ss; + ss << "("; + Arg2Str>::Run(ss, IdxSeq{}); + ss << ") -> " << Type2Str::v(); + return ss.str(); + } +}; + +template +struct FunctionInfoHelper; + +template +struct FunctionInfoHelper : FuncFunctorImpl {}; +template +struct FunctionInfoHelper : FuncFunctorImpl {}; + +/*! + * \brief Template class to get function signature of a function or functor. + * \tparam T The function/functor type. + * \note We need a decltype redirection because this helps lambda types. + */ +template +struct FunctionInfo : FunctionInfoHelper {}; + +template +struct FunctionInfo : FuncFunctorImpl {}; +template +struct FunctionInfo : FuncFunctorImpl {}; + +/*! \brief Using static function to output TypedPackedFunc signature */ +typedef std::string (*FGetFuncSignature)(); + +/*! + * \brief Auxilary argument value with context for error reporting + */ +class ArgValueWithContext { + public: + /*! + * \brief move constructor from another return value. + * \param args The argument list + * \param arg_index In a function call, this argument is at index arg_index (0-indexed). + * \param optional_name Name of the function being called. Can be nullptr if the function is not. + * \param f_sig Pointer to static function outputting signature of the function being called. + * named. + */ + TVM_FFI_INLINE ArgValueWithContext(const AnyView* args, int32_t arg_index, + const std::string* optional_name, FGetFuncSignature f_sig) + : args_(args), arg_index_(arg_index), optional_name_(optional_name), f_sig_(f_sig) {} + + template + TVM_FFI_INLINE operator Type() { + using TypeWithoutCR = std::remove_const_t>; + + if constexpr (std::is_same_v) { + return args_[arg_index_]; + } else if constexpr (std::is_same_v) { + return Any(args_[arg_index_]); + } else { + std::optional opt = args_[arg_index_].as(); + if (!opt.has_value()) { + TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny(); + TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ + << " when calling: `" + << (optional_name_ == nullptr ? "" : *optional_name_) + << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `" + << Type2Str::v() << "` but got `" + << TypeTraits::GetMismatchTypeInfo(&any_data) + << '`'; + } + return *std::move(opt); + } + } + + private: + const AnyView* args_; + int32_t arg_index_; + const std::string* optional_name_; + FGetFuncSignature f_sig_; +}; + +template +TVM_FFI_INLINE void unpack_call(std::index_sequence, const std::string* optional_name, + const F& f, [[maybe_unused]] const AnyView* args, + [[maybe_unused]] int32_t num_args, [[maybe_unused]] Any* rv) { + using FuncInfo = FunctionInfo; + FGetFuncSignature f_sig = FuncInfo::Sig; + + // somehow MSVC does not support the static constexpr member in this case, function is fine +#ifndef _MSC_VER + static_assert(FuncInfo::unpacked_supported, "The function signature do not support unpacked"); +#endif + constexpr size_t nargs = sizeof...(Is); + if (nargs != num_args) { + TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" + << (optional_name == nullptr ? "" : *optional_name) + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << nargs + << " but got " << num_args << " arguments"; + } + // use index sequence to do recursive-less unpacking + if constexpr (std::is_same_v) { + f(ArgValueWithContext(args, Is, optional_name, f_sig)...); + } else { + *rv = R(f(ArgValueWithContext(args, Is, optional_name, f_sig)...)); + } +} + +/*! + * \brief Move the safe call raised error to the caller + * \return The error + */ +TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { + TVMFFIObjectHandle handle; + TVMFFIErrorMoveFromRaised(&handle); + // handle is owned by caller + return Error( + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); +} + +/*! + * \brief Set the safe call raised error + * \param error The error + */ +TVM_FFI_INLINE static void SetSafeCallRaised(const Error& error) { + TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error)); +} +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_FUNCTION_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h new file mode 100644 index 000000000000..d3af2dd49077 --- /dev/null +++ b/ffi/include/tvm/ffi/memory.h @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/memory.h + * \brief Runtime memory management to allocate on heap object. + */ +#ifndef TVM_FFI_MEMORY_H_ +#define TVM_FFI_MEMORY_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! \brief Deleter function for obeject */ +typedef void (*FObjectDeleter)(TVMFFIObject* obj); + +/*! + * \brief Allocate an object using default allocator. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The ObjectPtr to the allocated object. + */ +template +inline ObjectPtr make_object(Args&&... args); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter = nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. + +/*! + * \brief Base class of object allocators that implements make. + * Use curiously recurring template pattern. + * + * \tparam Derived The derived class. + */ +template +class ObjAllocatorBase { + public: + /*! + * \brief Make a new object using the allocator. + * \tparam T The type to be allocated. + * \tparam Args The constructor signature. + * \param args The arguments. + */ + template + inline ObjectPtr make_object(Args&&... args) { + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this), std::forward(args)...); + TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); + ffi_ptr->ref_counter = 1; + ffi_ptr->type_index = T::RuntimeTypeIndex(); + ffi_ptr->deleter = Handler::Deleter(); + return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); + } + + /*! + * \tparam ArrayType The type to be allocated. + * \tparam ElemType The type of array element. + * \tparam Args The constructor signature. + * \param num_elems The number of array elements. + * \param args The arguments. + */ + template + inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { + using Handler = typename Derived::template ArrayHandler; + static_assert(std::is_base_of::value, + "make_inplace_array can only be used to create Object"); + ArrayType* ptr = + Handler::New(static_cast(this), num_elems, std::forward(args)...); + TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); + ffi_ptr->ref_counter = 1; + ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); + ffi_ptr->deleter = Handler::Deleter(); + return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); + } +}; + +// Simple allocator that uses new/delete. +class SimpleObjAllocator : public ObjAllocatorBase { + public: + template + class Handler { + public: + using StorageType = typename std::aligned_storage::type; + + template + static T* New(SimpleObjAllocator*, Args&&... args) { + // NOTE: the first argument is not needed for SimpleObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + StorageType* data = new StorageType(); + new (data) T(std::forward(args)...); + return reinterpret_cast(data); + } + + static FObjectDeleter Deleter() { return Deleter_; } + + private: + static void Deleter_(TVMFFIObject* objptr) { + T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + delete reinterpret_cast(tptr); + } + }; + + // Array handler that uses new/delete. + template + class ArrayHandler { + public: + using StorageType = typename std::aligned_storage::type; + // for now only support elements that aligns with array header. + static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && + sizeof(ArrayType) % alignof(ElemType) == 0, + "element alignment constraint"); + + template + static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { + // NOTE: the first argument is not needed for ArrayObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + size_t unit = sizeof(StorageType); + size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); + size_t num_storage_slots = (requested_size + unit - 1) / unit; + StorageType* data = new StorageType[num_storage_slots]; + new (data) ArrayType(std::forward(args)...); + return reinterpret_cast(data); + } + + static FObjectDeleter Deleter() { return Deleter_; } + + private: + static void Deleter_(TVMFFIObject* objptr) { + ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + StorageType* p = reinterpret_cast(tptr); + delete[] p; + } + }; +}; + +template +inline ObjectPtr make_object(Args&&... args) { + return SimpleObjAllocator().make_object(std::forward(args)...); +} + +template +inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { + return SimpleObjAllocator().make_inplace_array(num_elems, + std::forward(args)...); +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_MEMORY_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h new file mode 100644 index 000000000000..e86689ebe23a --- /dev/null +++ b/ffi/include/tvm/ffi/object.h @@ -0,0 +1,798 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/object.h + * \brief A managed object in the TVM FFI. + */ +#ifndef TVM_FFI_OBJECT_H_ +#define TVM_FFI_OBJECT_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +using TypeIndex = TVMFFITypeIndex; +using TypeInfo = TVMFFITypeInfo; + +/*! + * \brief Known type keys for pre-defined types. + */ +struct StaticTypeKey { + static constexpr const char* kTVMFFIAny = "Any"; + static constexpr const char* kTVMFFINone = "None"; + static constexpr const char* kTVMFFIBool = "bool"; + static constexpr const char* kTVMFFIInt = "int"; + static constexpr const char* kTVMFFIFloat = "float"; + static constexpr const char* kTVMFFIOpaquePtr = "void*"; + static constexpr const char* kTVMFFIDataType = "DataType"; + static constexpr const char* kTVMFFIDevice = "Device"; + static constexpr const char* kTVMFFIRawStr = "const char*"; + static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; + static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; + static constexpr const char* kTVMFFIBytes = "object.Bytes"; + static constexpr const char* kTVMFFIStr = "object.String"; + static constexpr const char* kTVMFFIShape = "object.Shape"; + static constexpr const char* kTVMFFINDArray = "object.NDArray"; +}; + +/*! + * \brief Get type key from type index + * \param type_index The input type index + * \return the type key + */ +inline std::string TypeIndexToTypeKey(int32_t type_index) { + const TypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + return std::string(type_info->type_key.data, type_info->type_key.size); +} + +namespace details { +// Helper to perform +// unsafe operations related to object +struct ObjectUnsafe; + +/*! + * Check if the type_index is an instance of TargetObjectType. + * + * \tparam TargetType The target object type to be checked. + * + * \param object_type_index The type index to be checked, caller + * ensures that the index is already within the object index range. + * + * \return Whether the target type is true. + */ +template +TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); +} // namespace details + +/*! + * \brief base class of all object containers. + * + * Sub-class of objects should declare the following static constexpr fields: + * + * - _type_index: + * Static type index of the object, if assigned to TypeIndex::kTVMFFIDynObject + * the type index will be assigned during runtime. + * Runtime type index can be accessed by ObjectType::TypeIndex(); + * - _type_key: + * The unique string identifier of the type. + * - _type_final: + * Whether the type is terminal type(there is no subclass of the type in the object system). + * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO + * It is still OK to sub-class a terminal object type T and construct it using make_object. + * But IsInstance check will only show that the object type is T(instead of the sub-class). + * + * The following two fields are necessary for base classes that can be sub-classed. + * + * - _type_child_slots: + * Number of reserved type index slots for child classes. + * Used for runtime optimization for type checking in IsInstance. + * If an object's type_index is within range of [type_index, type_index + _type_child_slots] + * Then the object can be quickly decided as sub-class of the current object class. + * If not, a fallback mechanism is used to check the global type table. + * Recommendation: set to estimate number of children needed. + * + * - _type_child_slots_can_overflow: + * Whether we can add additional child classes even if the number of child classes + * exceeds the _type_child_slots. A fallback mechanism to check type table will be used. + * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * + * Two macros are used to declare helper functions in the object: + * - Use TVM_FFI_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. + * - Use TVM_FFI_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. + * + * New objects can be created using make_object function. + * Which will automatically populate the type_index and deleter of the object. + */ +class Object { + protected: + /*! \brief header field that is the common prefix of all objects */ + TVMFFIObject header_; + + public: + Object() { + header_.ref_counter = 0; + header_.deleter = nullptr; + } + /*! + * Check if the object is an instance of TargetType. + * \tparam TargetType The target type to be checked. + * \return Whether the target type is true. + */ + template + bool IsInstance() const { + return details::IsObjectInstance(header_.type_index); + } + + /*! \return The internal runtime type index of the object. */ + int32_t type_index() const { return header_.type_index; } + + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. + */ + std::string GetTypeKey() const { + // the function checks that the info exists + const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index); + return std::string(type_info->type_key.data, type_info->type_key.size); + } + + /*! + * \return A hash value of the return of GetTypeKey. + */ + uint64_t GetTypeKeyHash() const { + // the function checks that the info exists + const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index); + return type_info->type_key_hash; + } + + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + static std::string TypeIndex2Key(int32_t tindex) { + const TypeInfo* type_info = TVMFFIGetTypeInfo(tindex); + return std::string(type_info->type_key.data, type_info->type_key.size); + } + + bool unique() const { return use_count() == 1; } + + /*! + * \return The usage count of the cell. + * \note We use stl style naming to be consistent with known API in shared_ptr. + */ + int32_t use_count() const { return details::AtomicLoadRelaxed(&(header_.ref_counter)); } + + // Information about the object + static constexpr const char* _type_key = "object.Object"; + + // Default object type properties for sub-classes + static constexpr bool _type_final = false; + static constexpr uint32_t _type_child_slots = 0; + static constexpr bool _type_child_slots_can_overflow = true; + // NOTE: static type index field of the class + static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; + // the static type depth of the class + static constexpr int32_t _type_depth = 0; + // extra fields used by plug-ins for attribute visiting + // and structural information + static constexpr const bool _type_has_method_visit_attrs = true; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + // The following functions are provided by macro + // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_DECLARE_FINAL_OBJECT_INFO + /*! + * \brief Get the runtime allocated type index of the type + * \note Getting this information may need dynamic calls into a global table. + */ + static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } + /*! + * \brief Internal function to get or allocate a runtime index. + */ + static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } + + private: + /*! \brief increase reference count */ + void IncRef() { details::AtomicIncrementRelaxed(&(header_.ref_counter)); } + + /*! \brief decrease reference count and delete the object */ + void DecRef() { + if (details::AtomicDecrementRelAcq(&(header_.ref_counter)) == 1) { + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_)); + } + } + } + + // friend classes + template + friend class ObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief A custom smart pointer for Object. + * \tparam T the content data type. + * \sa make_object + */ +template +class ObjectPtr { + public: + /*! \brief default constructor */ + ObjectPtr() {} + /*! \brief default constructor */ + ObjectPtr(std::nullptr_t) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other The value to be moved + */ + ObjectPtr(const ObjectPtr& other) // NOLINT(*) + : ObjectPtr(other.data_) {} + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + ObjectPtr(const ObjectPtr& other) // NOLINT(*) + : ObjectPtr(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + ObjectPtr(ObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template + ObjectPtr(ObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~ObjectPtr() { this->reset(); } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + void swap(ObjectPtr& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + /*! + * \return Get the content of the pointer + */ + T* get() const { return static_cast(data_); } + /*! + * \return The pointer + */ + T* operator->() const { return get(); } + /*! + * \return The reference + */ + T& operator*() const { // NOLINT(*) + return *get(); + } + /*! + * \brief copy assignment + * \param other The value to be assigned. + * \return reference to self. + */ + ObjectPtr& operator=(const ObjectPtr& other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + ObjectPtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignment + * \param other The value to be assigned. + * \return reference to self. + */ + ObjectPtr& operator=(ObjectPtr&& other) { // NOLINT(*) + // copy-and-swap idiom + ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief nullptr check + * \return result of comparison of internal pointer with nullptr. + */ + explicit operator bool() const { return get() != nullptr; } + /*! \brief reset the content of ptr to be nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecRef(); + data_ = nullptr; + } + } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } + /*! \return whether the reference is unique */ + bool unique() const { return data_ != nullptr && data_->use_count() == 1; } + /*! \return Whether two ObjectPtr do not equal each other */ + bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } + /*! \return Whether two ObjectPtr equals each other */ + bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } + /*! \return Whether the pointer is nullptr */ + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + /*! \return Whether the pointer is not nullptr */ + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + private: + /*! \brief internal pointer field */ + Object* data_{nullptr}; + /*! + * \brief constructor from Object + * \param data The data pointer + */ + explicit ObjectPtr(Object* data) : data_(data) { + if (data_ != nullptr) { + data_->IncRef(); + } + } + // friend classes + friend class Object; + friend class ObjectRef; + friend struct ObjectPtrHash; + template + friend class ObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief Optional data type in FFI. + * \tparam T The underlying type of the optional. + * + * \note Compared to std::optional, Optional + * akes less storage as it used nullptr to represent nullopt. + */ +template +class Optional; + +/*! \brief Base class of all object reference */ +class ObjectRef { + public: + /*! \brief default constructor */ + ObjectRef() = default; + /*! \brief copy constructor */ + ObjectRef(const ObjectRef& other) = default; + /*! \brief move constructor */ + ObjectRef(ObjectRef&& other) = default; + /*! \brief copy assignment */ + ObjectRef& operator=(const ObjectRef& other) = default; + /*! \brief move assignment */ + ObjectRef& operator=(ObjectRef&& other) = default; + /*! \brief Constructor from existing object ptr */ + explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { return data_ == other.data_; } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef& other) const { return data_ == other.data_; } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } + /*! + * \return whether the object is defined. + */ + bool defined() const { return data_ != nullptr; } + /*! \return the internal object pointer */ + const Object* get() const { return data_.get(); } + /*! \return the internal object pointer */ + const Object* operator->() const { return get(); } + /*! \return whether the reference is unique */ + bool unique() const { return data_.unique(); } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_.use_count(); } + + /*! + * \brief Try to downcast the internal Object to a + * raw pointer of a corresponding type. + * + * The function will return a nullptr if the cast failed. + * + * if (const AddNode *ptr = node_ref.as()) { + * // This is an add node + * } + * + * \tparam ObjectType the target type, must be a subtype of Object + * \return The pointer to the requested type. + */ + template >> + const ObjectType* as() const { + if (data_ != nullptr && data_->IsInstance()) { + return static_cast(data_.get()); + } else { + return nullptr; + } + } + + /*! + * \brief Try to downcast the ObjectRef to Optional of the requested type. + * + * The function will return a std::nullopt if the cast or if the pointer is nullptr. + * + * \tparam ObjectRefType the target type, must be a subtype of ObjectRef' + * \return The optional value of the requested type. + */ + template >> + TVM_FFI_INLINE std::optional as() const { + if (data_ != nullptr) { + if (data_->IsInstance()) { + return ObjectRefType(data_); + } else { + return std::nullopt; + } + } else { + return std::nullopt; + } + } + /*! + * \brief Get the type index of the ObjectRef + * \return The type index of the ObjectRef + */ + int32_t type_index() const { + return data_ != nullptr ? data_->type_index() : TypeIndex::kTVMFFINone; + } + + /*! + * \brief Get the type key of the ObjectRef + * \return The type key of the ObjectRef + */ + std::string GetTypeKey() const { + return data_ != nullptr ? data_->GetTypeKey() : StaticTypeKey::kTVMFFINone; + } + + /*! \brief type indicate the container type. */ + using ContainerType = Object; + // Default type properties for the reference class. + static constexpr bool _type_is_nullable = true; + + protected: + /*! \brief Internal pointer that backs the reference. */ + ObjectPtr data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { return data_.get(); } + // friend classes. + friend struct ObjectPtrHash; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +// forward delcare variant +template +class Variant; + +/*! \brief ObjectRef hash functor */ +struct ObjectPtrHash { + size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } + + template + size_t operator()(const ObjectPtr& a) const { + return std::hash()(a.get()); + } + + template + TVM_FFI_INLINE size_t operator()(const Variant& a) const; +}; + +/*! \brief ObjectRef equal functor */ +struct ObjectPtrEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } + + template + bool operator()(const ObjectPtr& a, const ObjectPtr& b) const { + return a == b; + } + + template + TVM_FFI_INLINE bool operator()(const Variant& a, const Variant& b) const; +}; + +// If dynamic type is enabled, we still need to register the runtime type of parent +#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static_assert(!ParentType::_type_final, "ParentType marked as final"); \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + TVMFFIByteArray type_key{TypeName::_type_key, \ + std::char_traits::length(TypeName::_type_key)}; \ + static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ + &type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return tindex; \ + } \ + static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex() + +/*! + * \brief Helper macro to declare a object that comes with static type index. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ + static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ + TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) + +/*! + * \brief helper macro to declare a base object type that can be inherited. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static_assert(!ParentType::_type_final, "ParentType marked as final"); \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + TVMFFIByteArray type_key{TypeName::_type_key, \ + std::char_traits::length(TypeName::_type_key)}; \ + static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ + &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return tindex; \ + } \ + static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \ + static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex() + +/*! + * \brief helper macro to declare type information in a final class. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr int _type_child_slots = 0; \ + static const constexpr bool _type_final = true; \ + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) + +/* + * \brief Define object reference methods. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + * + * \note This macro also defines the default constructor that puts the ObjectRef + * in undefined state initially. + */ +#define TVM_FFI_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ + using ContainerType = ObjectName + +/* + * \brief Define object reference methods do not have undefined state. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName + +/* + * \brief Define object reference methods of whose content is mutable. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + * \note We recommend making objects immutable when possible. + * This macro is only reserved for objects that stores runtime states. + */ +#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ + using ContainerType = ObjectName; + +/* + * \brief Define object reference methods that is both not nullable and mutable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ + ObjectName* get() const { return operator->(); } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + +namespace details { +template +TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { + static_assert(std::is_base_of_v); + // Everything is a subclass of object. + if constexpr (std::is_same::value) return true; + + if constexpr (TargetType::_type_final) { + // if the target type is a final type + // then we only need to check the equivalence. + return object_type_index == TargetType::RuntimeTypeIndex(); + } + + // if target type is a non-leaf type + // Check if type index falls into the range of reserved slots. + int32_t target_type_index = TargetType::RuntimeTypeIndex(); + int32_t begin = target_type_index; + // The condition will be optimized by constant-folding. + if constexpr (TargetType::_type_child_slots != 0) { + // total_slots = child_slots + 1 (including self) + int32_t end = begin + TargetType::_type_child_slots + 1; + if (object_type_index >= begin && object_type_index < end) return true; + } else { + if (object_type_index == begin) return true; + } + if (!TargetType::_type_child_slots_can_overflow) return false; + // Invariance: parent index is always smaller than the child. + if (object_type_index < target_type_index) return false; + // Do a runtime lookup of type information + // the function checks that the info exists + const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index); + return (type_info->type_depth > TargetType::_type_depth && + type_info->type_acenstors[TargetType::_type_depth] == target_type_index); +} + +/*! + * \brief Namespace to internally manipulate object class. + * \note These functions are only supposed to be used by internal + * implementations and not external users of the tvm::ffi + */ +struct ObjectUnsafe { + // NOTE: get ffi header from an object + static TVM_FFI_INLINE TVMFFIObject* GetHeader(const Object* src) { + return const_cast(&(src->header_)); + } + + template + static TVM_FFI_INLINE int64_t GetObjectOffsetToSubclass() { + return (reinterpret_cast(&(static_cast(nullptr)->header_)) - + reinterpret_cast(&(static_cast(nullptr)->header_))); + } + + template + static TVM_FFI_INLINE ObjectPtr ObjectPtrFromObjectRef(const ObjectRef& ref) { + if constexpr (std::is_same_v) { + return ref.data_; + } else { + return tvm::ffi::ObjectPtr(ref.data_.data_); + } + } + + template + static TVM_FFI_INLINE ObjectPtr ObjectPtrFromObjectRef(ObjectRef&& ref) { + if constexpr (std::is_same_v) { + return std::move(ref.data_); + } else { + return tvm::ffi::ObjectPtr(std::move(ref.data_.data_)); + } + } + + template + static TVM_FFI_INLINE ObjectPtr ObjectPtrFromOwned(Object* raw_ptr) { + tvm::ffi::ObjectPtr ptr; + ptr.data_ = raw_ptr; + return ptr; + } + + template + static TVM_FFI_INLINE ObjectPtr ObjectPtrFromOwned(TVMFFIObject* obj_ptr) { + return ObjectPtrFromOwned(reinterpret_cast(obj_ptr)); + } + + template + static TVM_FFI_INLINE T* RawObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { + // NOTE: this is important to first cast to Object* + // then cast back to T* because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + return static_cast(reinterpret_cast(obj_ptr)); + } + + // Create ObjectPtr from unowned ptr + template + static TVM_FFI_INLINE ObjectPtr ObjectPtrFromUnowned(Object* raw_ptr) { + return tvm::ffi::ObjectPtr(raw_ptr); + } + + template + static TVM_FFI_INLINE ObjectPtr ObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { + return tvm::ffi::ObjectPtr(reinterpret_cast(obj_ptr)); + } + + static TVM_FFI_INLINE void DecRefObjectHandle(TVMFFIObjectHandle handle) { + reinterpret_cast(handle)->DecRef(); + } + + static TVM_FFI_INLINE void IncRefObjectHandle(TVMFFIObjectHandle handle) { + reinterpret_cast(handle)->IncRef(); + } + + static TVM_FFI_INLINE Object* RawObjectPtrFromObjectRef(const ObjectRef& src) { + return src.data_.data_; + } + + static TVM_FFI_INLINE TVMFFIObject* TVMFFIObjectPtrFromObjectRef(const ObjectRef& src) { + return GetHeader(src.data_.data_); + } + + template + static TVM_FFI_INLINE TVMFFIObject* TVMFFIObjectPtrFromObjectPtr(const ObjectPtr& src) { + return GetHeader(src.data_); + } + + template + static TVM_FFI_INLINE TVMFFIObject* MoveObjectPtrToTVMFFIObjectPtr(ObjectPtr&& src) { + Object* obj_ptr = src.data_; + src.data_ = nullptr; + return GetHeader(obj_ptr); + } + + static TVM_FFI_INLINE TVMFFIObject* MoveObjectRefToTVMFFIObjectPtr(ObjectRef&& src) { + Object* obj_ptr = src.data_.data_; + src.data_.data_ = nullptr; + return GetHeader(obj_ptr); + } +}; +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_OBJECT_H_ diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h new file mode 100644 index 000000000000..7b3f69ef9919 --- /dev/null +++ b/ffi/include/tvm/ffi/optional.h @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/optional.h + * \brief Runtime Optional container types. + * \note Optional specializes for T is ObjectRef and used nullptr to indicate nullopt. + */ +#ifndef TVM_FFI_OPTIONAL_H_ +#define TVM_FFI_OPTIONAL_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { + +// Note: We place optional in tvm/ffi instead of tvm/ffi/container +// because optional itself is an inherent core component of the FFI system. + +template +inline constexpr bool is_optional_type_v = false; + +template +inline constexpr bool is_optional_type_v> = true; + +// we can safely used ptr based optional for ObjectRef types +// that do not have additional data members and virtual functions. +template +inline constexpr bool use_ptr_based_optional_v = + (std::is_base_of_v && !is_optional_type_v); + +// Specialization for non-ObjectRef types. +// simply fallback to std::optional +template +class Optional>> { + public: + // default constructors. + Optional() = default; + Optional(const Optional& other) : data_(other.data_) {} + Optional(Optional&& other) : data_(std::move(other.data_)) {} + Optional(std::optional other) : data_(std::move(other)) {} // NOLINT(*) + Optional(std::nullopt_t) {} // NOLINT(*) + // normal value handling. + Optional(T other) // NOLINT(*) + : data_(std::move(other)) {} + + TVM_FFI_INLINE Optional& operator=(const Optional& other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Optional& operator=(Optional&& other) { + data_ = std::move(other.data_); + return *this; + } + + TVM_FFI_INLINE Optional& operator=(T other) { + data_ = std::move(other); + return *this; + } + + TVM_FFI_INLINE Optional& operator=(std::nullopt_t) { + data_ = std::nullopt; + return *this; + } + + TVM_FFI_INLINE const T& value() const& { + if (!data_.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return *data_; + } + + TVM_FFI_INLINE T&& value() && { + if (!data_.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return *std::move(data_); + } + + template > + TVM_FFI_INLINE T value_or(U&& default_value) const { + return data_.value_or(std::forward(default_value)); + } + + TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.has_value(); } + + TVM_FFI_INLINE bool has_value() const noexcept { return data_.has_value(); } + + TVM_FFI_INLINE bool operator==(const Optional& other) const { return data_ == other.data_; } + + TVM_FFI_INLINE bool operator!=(const Optional& other) const { return data_ != other.data_; } + + template + TVM_FFI_INLINE bool operator==(const U& other) const { + return data_ == other; + } + template + TVM_FFI_INLINE bool operator!=(const U& other) const { + return data_ != other; + } + + /*! + * \brief Direct access to the value. + * \return the xvalue reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T&& operator*() && noexcept { return *std::move(data_); } + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE const T& operator*() const& noexcept { return *data_; } + + private: + std::optional data_; +}; + +// Specialization for ObjectRef types. +// nullptr is treated as std::nullopt. +template +class Optional>> : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + Optional() = default; + Optional(const Optional& other) : ObjectRef(other.data_) {} + Optional(Optional&& other) : ObjectRef(std::move(other.data_)) {} + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + // nullopt hanlding + Optional(std::nullopt_t) {} // NOLINT(*) + + // handle conversion from std::optional + Optional(std::optional other) { // NOLINT(*) + if (other.has_value()) { + *this = *std::move(other); + } + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + + TVM_FFI_INLINE Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + + TVM_FFI_INLINE Optional& operator=(const Optional& other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + + TVM_FFI_INLINE Optional& operator=(Optional&& other) { + data_ = std::move(other.data_); + return *this; + } + + TVM_FFI_INLINE T value() const& { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return T(data_); + } + + TVM_FFI_INLINE T value() && { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return T(std::move(data_)); + } + + template > + TVM_FFI_INLINE T value_or(U&& default_value) const { + return data_ != nullptr ? T(data_) : T(std::forward(default_value)); + } + + TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } + + TVM_FFI_INLINE bool has_value() const { return data_ != nullptr; } + + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T operator*() const& noexcept { return T(data_); } + + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T operator*() && noexcept { return T(std::move(data_)); } + + TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } + + // operator overloadings + TVM_FFI_INLINE auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + return EQToOptional(other); + } + TVM_FFI_INLINE auto operator!=(const Optional& other) const { return NEToOptional(other); } + + TVM_FFI_INLINE auto operator==(const std::optional& other) const { + // support case where sub-class returns a symbolic ref type. + return EQToOptional(other); + } + TVM_FFI_INLINE auto operator!=(const std::optional& other) const { + return NEToOptional(other); + } + + TVM_FFI_INLINE auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (has_value()) return operator*() == other; + return RetType(false); + } + + TVM_FFI_INLINE auto operator!=(const T& other) const { return !(*this == other); } + + template + TVM_FFI_INLINE auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (!has_value()) return RetType(false); + return operator*() == other; + } + + template + TVM_FFI_INLINE auto operator!=(const U& other) const { + using RetType = decltype(value() != other); + if (!has_value()) return RetType(true); + return operator*() != other; + } + + /*! + * \return The internal object pointer with container type of T. + * \note This function do not perform not-null checking. + */ + TVM_FFI_INLINE const ContainerType* get() const { + return static_cast(data_.get()); + } + + private: + template + TVM_FFI_INLINE auto EQToOptional(const U& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (has_value() && other.has_value()) { + return operator*() == *other; + } else { + // one of them is nullptr. + return RetType(false); + } + } + + template + TVM_FFI_INLINE auto NEToOptional(const U& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (has_value() && other.has_value()) { + return operator*() != *other; + } else { + // one of them is nullptr. + return RetType(true); + } + } +}; +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_OPTIONAL_H_ diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h new file mode 100644 index 000000000000..766b9b809958 --- /dev/null +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/reflection.h + * \brief Base reflection support to access object fields. + */ +#ifndef TVM_FFI_REFLECTION_REFLECTION_H_ +#define TVM_FFI_REFLECTION_REFLECTION_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { +namespace details { + +/*! + * \brief Get the byte offset of a class member field. + * + * \tparam The original class. + * \tparam T the field type. + * + * \param field_ptr A class member pointer + * \returns The byteoffset + */ +template +inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { + int64_t field_offset_to_class = + reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); + return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); +} + +struct ReflectionDefFinish {}; + +class ReflectionDef { + public: + explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {} + + template + ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) { + RegisterField(name, field_ptr, true); + return *this; + } + + template + ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) { + RegisterField(name, field_ptr, false); + return *this; + } + + private: + template + void RegisterField(const char* name, T Class::*field_ptr, bool readonly) { + TVMFFIFieldInfo info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.field_static_type_index = TypeToFieldStaticTypeIndex::value; + // store byte offset and setter, getter + // so the same setter can be reused for all the same type + info.byte_offset = GetFieldByteOffsetToObject(field_ptr); + info.readonly = readonly; + info.getter = FieldGetter; + info.setter = FieldSetter; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIRegisterTypeField(type_index_, &info)); + } + + template + static int FieldGetter(void* field, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int FieldSetter(void* field, const TVMFFIAny* value) { + TVM_FFI_SAFE_CALL_BEGIN(); + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + TVM_FFI_SAFE_CALL_END(); + } + + int32_t type_index_; +}; + +/*! + * \brief helper function to get reflection field info by type key and field name + */ +inline const TVMFFIFieldInfo* GetReflectionFieldInfo(std::string_view type_key, + const char* field_name) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TypeInfo* info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_fields; ++i) { + if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { + return &(info->fields[i]); + } + } + TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in " << type_key; +} + +/*! + * \brief helper wrapper class to obtain a getter. + */ +class ReflectionFieldGetter { + public: + explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} + + Any operator()(const Object* obj_ptr) const { + Any result; + const void* addr = reinterpret_cast(obj_ptr) + field_info_->byte_offset; + TVM_FFI_CHECK_SAFE_CALL( + field_info_->getter(const_cast(addr), reinterpret_cast(&result))); + return result; + } + + Any operator()(const ObjectPtr& obj_ptr) const { return operator()(obj_ptr.get()); } + + Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); } + + private: + const TVMFFIFieldInfo* field_info_; +}; + +#define TVM_FFI_REFLECTION_REG_VAR_DEF \ + static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::ReflectionDef& __TVMFFIReflectionReg + +/*! + * helper macro to define a reflection definition for an object + */ +#define TVM_FFI_REFLECTION_DEF(TypeName) \ + TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ + ::tvm::ffi::details::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex()) +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_REFLECTION_H_ diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h new file mode 100644 index 000000000000..a4a363fa17cd --- /dev/null +++ b/ffi/include/tvm/ffi/rvalue_ref.h @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/rvalue_ref.h + * \brief Helper class to define rvalue reference type. + */ +#ifndef TVM_FFI_RVALUE_REF_H_ +#define TVM_FFI_RVALUE_REF_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Helper class to define rvalue reference type. + * + * By default, FFI pass all values by lvalue reference. + * + * However, we do allow users to intentionally mark a function parameter + * as RValueRef. In such cases, the caller can choose to pass parameter + * wrapped by RValueRef to the function. In which case the parameter + * can be directly moved by the callee. The caller can also choose to pass + * a normal lvalue to the function, in such case a copy will be triggered. + * + * To keep FFI checking overhead minimal, we do not handle case when rvalue + * is passed, but the callee did not declare the parameter as RValueRef. + * + * This design allows us to still leverage move semantics for parameters that + * need copy on write scenarios (and requires an unique copy). + * + * \code + * + * void Example() { + * auto append = Function::FromUnpacked([](RValueRef> ref, int val) -> Array { + * Array arr = *std::move(ref); + * assert(arr.unique()); + * arr.push_back(val); + * return arr; + * }); + * Array a = Array({1, 2}); + * // as we use rvalue ref to move a into append + * // we keep a single copy of the Array without creating new copies during copy-on-write + * a = append(RvalueRef(std::move(a)), 3); + * assert(a.size() == 3); + * } + * + * \endcode + */ +template >> +class RValueRef { + public: + /*! \brief only allow move constructor from rvalue of T */ + explicit RValueRef(TObjRef&& data) + : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} + + /*! \brief return the data as rvalue */ + TObjRef operator*() && { return TObjRef(std::move(data_)); } + + private: + mutable ObjectPtr data_; + + template + friend struct TypeTraits; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + static constexpr bool storage_enabled = false; + + static TVM_FFI_INLINE void CopyToAnyView(const RValueRef& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIObjectRValueRef; + // store the address of the ObjectPtr, which allows us to move the value + // and set the original ObjectPtr to nullptr + result->v_ptr = &(src.data_); + } + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); + // object type does not match up, we need to try to convert the object + // in this case we do not move the original rvalue ref since conversion creates a copy + TVMFFIAny tmp_any; + tmp_any.type_index = rvalue_ref->get()->type_index(); + + tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); + return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; + } else { + return TypeTraits::GetMismatchTypeInfo(src); + } + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView( + const TVMFFIAny* src) { + // first try rvalue conversion + if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); + TVMFFIAny tmp_any; + tmp_any.type_index = rvalue_ref->get()->type_index(); + tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); + // fast path, storage type matches, direct move the rvalue ref + if (TypeTraits::CheckAnyStorage(&tmp_any)) { + return RValueRef(TObjRef(std::move(*rvalue_ref))); + } + if (std::optional opt = TypeTraits::TryConvertFromAnyView(&tmp_any)) { + // object type does not match up, we need to try to convert the object + // in this case we do not move the original rvalue ref since conversion creates a copy + return RValueRef(*std::move(opt)); + } + return std::nullopt; + } + // try lvalue conversion + if (std::optional opt = TypeTraits::TryConvertFromAnyView(src)) { + return RValueRef(*std::move(opt)); + } else { + return std::nullopt; + } + } + + static TVM_FFI_INLINE std::string TypeStr() { + return "RValueRef<" + TypeTraits::TypeStr() + ">"; + } +}; +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_RVALUE_REF_H_ diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h new file mode 100644 index 000000000000..1c22f10892ac --- /dev/null +++ b/ffi/include/tvm/ffi/string.h @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/string.h + * \brief Runtime Bytes and String types. + */ +#ifndef TVM_FFI_STRING_H_ +#define TVM_FFI_STRING_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// Note: We place string in tvm/ffi instead of tvm/ffi/container +// because string itself needs special handling and is an inherent +// core component for return string handling. +// The following dependency relation holds +// any -> string -> object + +namespace tvm { +namespace ffi { + +/*! \brief Base class for bytes and string. */ +class BytesObjBase : public Object, public TVMFFIByteArray {}; + +/*! + * \brief An object representing bytes. + * \note We use separate object for bytes to follow python convention + * and indicate passing of raw bytes. + * Bytes can be converted from/to string. + */ +class BytesObj : public BytesObjBase { + public: + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIBytes; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(BytesObj, Object); +}; + +/*! \brief An object representing string. It's POD type. */ +class StringObj : public BytesObjBase { + public: + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIStr; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object); +}; + +namespace details { + +// String moved from std::string +// without having to trigger a copy +template +class BytesObjStdImpl : public Base { + public: + explicit BytesObjStdImpl(std::string other) : data_{other} { + this->data = data_.data(); + this->size = data_.size(); + } + + private: + std::string data_; +}; + +// inplace string allocation +template +TVM_FFI_INLINE ObjectPtr MakeInplaceBytes(const char* data, size_t length) { + ObjectPtr p = make_inplace_array_object(length + 1); + static_assert(alignof(Base) % alignof(char) == 0); + static_assert(sizeof(Base) % alignof(char) == 0); + char* dest_data = reinterpret_cast(p.get()) + sizeof(Base); + p->data = dest_data; + p->size = length; + std::memcpy(dest_data, data, length); + dest_data[length] = '\0'; + return p; +} +} // namespace details + +/*! + * \brief Managed reference of byte array. + */ +class Bytes : public ObjectRef { + public: + /*! + * \brief constructor from char [N] + * + * \param other a char array. + */ + Bytes(const char* data, size_t size) // NOLINT(*) + : ObjectRef(details::MakeInplaceBytes(data, size)) {} + /*! + * \brief constructor from char [N] + * + * \param other a char array. + */ + Bytes(TVMFFIByteArray bytes) // NOLINT(*) + : ObjectRef(details::MakeInplaceBytes(bytes.data, bytes.size)) {} + /*! + * \brief constructor from char [N] + * + * \param other a char array. + */ + Bytes(std::string other) // NOLINT(*) + : ObjectRef(make_object>(std::move(other))) {} + /*! + * \brief Swap this String with another string + * \param other The other string + */ + void swap(Bytes& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + template + Bytes& operator=(T&& other) { + // copy-and-swap idiom + Bytes(std::forward(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { return get()->size; } + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char* data() const { return get()->data; } + /*! + * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, BytesObj); + + private: + /*! + * \brief Compare two char sequence + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * \return int zero if both char sequences compare equal. negative if this + * appear before other, positive otherwise. + */ + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + + friend struct AnyEqual; + friend class String; +}; + +/*! + * \brief Reference to string objects. + * + * \code + * + * // Example to create runtime String reference object from std::string + * std::string s = "hello world"; + * + * // You can create the reference from existing std::string + * String ref{std::move(s)}; + * + * // You can rebind the reference to another string. + * ref = std::string{"hello world2"}; + * + * // You can use the reference as hash map key + * std::unordered_map m; + * m[ref] = 1; + * + * // You can compare the reference object with other string objects + * assert(ref == "hello world", true); + * + * // You can convert the reference to std::string again + * string s2 = (string)ref; + * + * \endcode + */ +class String : public ObjectRef { + public: + String(nullptr_t) = delete; // NOLINT(*) + + /*! + * \brief constructor from char [N] + * + * \param other a char array. + */ + template + String(const char other[N]) // NOLINT(*) + : ObjectRef(details::MakeInplaceBytes(other, N)) {} + + /*! + * \brief constructor + */ + String() : String("") {} + + /*! + * \brief constructor from raw string + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : ObjectRef(details::MakeInplaceBytes(other, std::strlen(other))) {} + + /*! + * \brief constructor from raw string + * + * \param other a char array. + */ + String(const char* other, size_t size) // NOLINT(*) + : ObjectRef(details::MakeInplaceBytes(other, size)) {} + + /*! + * \brief Construct a new string object + * \param other The std::string object to be copied + */ + String(const std::string& other) // NOLINT(*) + : ObjectRef(details::MakeInplaceBytes(other.data(), other.size())) {} + + /*! + * \brief Construct a new string object + * \param other The std::string object to be moved + */ + String(std::string&& other) // NOLINT(*) + : ObjectRef(make_object>(std::move(other))) {} + /*! + * \brief Swap this String with another string + * \param other The other string + */ + void swap(String& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + template + String& operator=(T&& other) { + // copy-and-swap idiom + String(std::forward(other)).swap(*this); // NOLINT(*) + return *this; + } + + /*! + * \brief Compares this String object to other + * + * \param other The String to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String& other) const { + return Bytes::memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this String object to other + * + * \param other The string to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const std::string& other) const { + return Bytes::memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this to other + * + * \param other The character array to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const char* other) const { + return Bytes::memncmp(data(), other, size(), std::strlen(other)); + } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char* c_str() const { return get()->data; } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { + const auto* ptr = get(); + return ptr->size; + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t length() const { return size(); } + + /*! + * \brief Retun if the string is empty + * + * \return true if empty, false otherwise. + */ + bool empty() const { return size() == 0; } + + /*! + * \brief Read an element. + * \param pos The position at which to read the character. + * + * \return The char at position + */ + char at(size_t pos) const { + if (pos < size()) { + return data()[pos]; + } else { + throw std::out_of_range("tvm::String index out of bounds"); + } + } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char* data() const { return get()->data; } + + /*! + * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + + private: + /*! + * \brief Concatenate two char sequences + * + * \param lhs Pointers to the lhs char array + * \param lhs_size The size of the lhs char array + * \param rhs Pointers to the rhs char array + * \param rhs_size The size of the rhs char array + * + * \return The concatenated char sequence + */ + static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { + std::string ret(lhs, lhs_size); + ret.append(rhs, rhs_size); + return String(ret); + } + + // Overload + operator + friend String operator+(const String& lhs, const String& rhs); + friend String operator+(const String& lhs, const std::string& rhs); + friend String operator+(const std::string& lhs, const String& rhs); + friend String operator+(const String& lhs, const char* rhs); + friend String operator+(const char* lhs, const String& rhs); +}; + +/*! \brief Convert TVMFFIByteArray to std::string_view */ +TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { + return std::string_view(str.data, str.size); +} + +// const char*, requirement: not nullable, do not retain ownership +template +struct TypeTraits : public TypeTraitsBase { + // NOTE: only enable implicit conversion into AnyView + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr; + static constexpr bool storage_enabled = false; + + static TVM_FFI_INLINE void CopyToAnyView(const char src[N], TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIRawStr; + result->v_c_str = src; + } + + static TVM_FFI_INLINE void MoveToAny(const char src[N], TVMFFIAny* result) { + // when we need to move to any, convert to owned object first + ObjectRefTypeTraitsBase::MoveToAny(String(src), result); + } +}; + +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr; + static constexpr bool storage_enabled = false; + + static TVM_FFI_INLINE void CopyToAnyView(const char* src, TVMFFIAny* result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIRawStr; + result->v_c_str = src; + } + + static TVM_FFI_INLINE void MoveToAny(const char* src, TVMFFIAny* result) { + // when we need to move to any, convert to owned object first + ObjectRefTypeTraitsBase::MoveToAny(String(src), result); + } + // Do not allow const char* in a container, so we do not need CheckAnyStorage + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) { + return static_cast(src->v_c_str); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "const char*"; } +}; + +// TVMFFIByteArray, requirement: not nullable, do not retain ownership +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIByteArrayPtr; + static constexpr bool storage_enabled = false; + + static TVM_FFI_INLINE void CopyToAnyView(TVMFFIByteArray* src, TVMFFIAny* result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIByteArrayPtr; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = src; + } + + static TVM_FFI_INLINE void MoveToAny(TVMFFIByteArray* src, TVMFFIAny* result) { + ObjectRefTypeTraitsBase::MoveToAny(Bytes(*src), result); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView( + const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { + return static_cast(src->v_ptr); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIByteArrayPtr; } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + +// specialize to enable implicit conversion from TVMFFIByteArray* +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBytes; + static TVM_FFI_INLINE Bytes ConvertFallbackValue(TVMFFIByteArray* src) { return Bytes(*src); } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + +// specialize to enable implicit conversion from const char* +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIStr; + static TVM_FFI_INLINE String ConvertFallbackValue(const char* src) { return String(src); } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + +template <> +struct TypeTraits + : public FallbackOnlyTraitsBase { + static TVM_FFI_INLINE void CopyToAnyView(const std::string& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIRawStr; + result->v_c_str = src.c_str(); + } + + static TVM_FFI_INLINE void MoveToAny(std::string src, TVMFFIAny* result) { + // when we need to move to any, convert to owned object first + ObjectRefTypeTraitsBase::MoveToAny(String(std::move(src)), result); + } + + static TVM_FFI_INLINE std::string TypeStr() { return "std::string"; } + + static TVM_FFI_INLINE std::string ConvertFallbackValue(const char* src) { + return std::string(src); + } + + static TVM_FFI_INLINE std::string ConvertFallbackValue(TVMFFIByteArray* src) { + return std::string(src->data, src->size); + } + + static TVM_FFI_INLINE std::string ConvertFallbackValue(Bytes src) { + return src.operator std::string(); + } + + static TVM_FFI_INLINE std::string ConvertFallbackValue(String src) { + return src.operator std::string(); + } +}; + +inline String operator+(const String& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const std::string& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const std::string& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const char* lhs, const String& rhs) { + size_t lhs_size = std::strlen(lhs); + size_t rhs_size = rhs.size(); + return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const char* rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = std::strlen(rhs); + return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); +} + +// Overload < operator +inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +// Overload > operator +inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +// Overload <= operator +inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +// Overload >= operator +inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +// Overload == operator +inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +// Overload != operator +inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline std::ostream& operator<<(std::ostream& out, const String& input) { + out.write(input.data(), input.size()); + return out; +} + +inline int Bytes::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) return 0; + + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (lhs[i] < rhs[i]) return -1; + if (lhs[i] > rhs[i]) return 1; + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } +} +} // namespace ffi +} // namespace tvm + +namespace std { + +template <> +struct hash<::tvm::ffi::Bytes> { + std::size_t operator()(const ::tvm::ffi::Bytes& bytes) const { + return ::tvm::ffi::details::StableHashBytes(bytes.data(), bytes.size()); + } +}; + +template <> +struct hash<::tvm::ffi::String> { + std::size_t operator()(const ::tvm::ffi::String& str) const { + return ::tvm::ffi::details::StableHashBytes(str.data(), str.size()); + } +}; +} // namespace std +#endif // TVM_FFI_STRING_H_ diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h new file mode 100644 index 000000000000..d350aea82ac7 --- /dev/null +++ b/ffi/include/tvm/ffi/type_traits.h @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/object.h + * \brief A managed object in the TVM FFI. + */ +#ifndef TVM_FFI_TYPE_TRAITS_H_ +#define TVM_FFI_TYPE_TRAITS_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. + * + * The function specifications of TypeTraits + * + * - CopyToAnyView: Convert a value T to AnyView + * - MoveToAny: Move a value to Any + * - CheckAnyStorage: Check if a Any stores a result of MoveToAny of current T. + * - CopyFromAnyStorageAfterCheck: Copy a value T from Any storage after we pass CheckAnyStorage. + * - MoveFromAnyStorageAfterCheck: Move a value T from Any storage after we pass CheckAnyStorage. + * - TryConvertFromAnyView: Convert a AnyView to a T, we may apply type conversion. + * - GetMismatchTypeInfo: Get the type key of a type when TryConvertFromAnyView fails. + * - TypeStr: Get the type key of a type + * + * It is possible that CheckAnyStorage is false but TryConvertFromAnyView still works. + * + * For example, when Any x stores int, TypeTraits::CheckAnyStorage(x) will be false, + * but TypeTraits::TryConvertFromAnyView(x) will return a corresponding float value + * via type conversion. + * + * CheckAnyStorage is mainly used in recursive container such as Array to + * decide if a new Array needed to be created via recursive conversion, + * or we can use the current container as is when converting to Array. + * + * A container array: Array satisfies the following invariant: + * - `all(TypeTraits::CheckAnyStorage(x) for x in the array)`. + */ +template +struct TypeTraits { + /*! \brief Whether the type is enabled in FFI. */ + static constexpr bool convert_enabled = false; + /*! \brief Whether the type can appear as a storage type in Container */ + static constexpr bool storage_enabled = false; +}; + +/*! + * \brief TypeTraits that removes const and reference keywords. + * \tparam T the original type + */ +template +using TypeTraitsNoCR = TypeTraits>>; + +template +inline constexpr bool use_default_type_traits_v = true; + +struct TypeTraitsBase { + static constexpr bool convert_enabled = true; + static constexpr bool storage_enabled = true; + // get mismatched type when result mismatches the trait. + // this function is called after TryConvertFromAnyView fails + // to get more detailed type information in runtime + // especially when the error involves nested container type + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* source) { + return TypeIndexToTypeKey(source->type_index); + } +}; + +template +struct TypeToFieldStaticTypeIndex { + static constexpr int32_t value = TypeIndex::kTVMFFIAny; +}; + +template +struct TypeToFieldStaticTypeIndex::convert_enabled>> { + static constexpr int32_t value = TypeTraits::field_static_type_index; +}; + +template +struct TypeToRuntimeTypeIndex { + static int32_t v() { return TypeToFieldStaticTypeIndex::value; } +}; + +template +struct TypeToRuntimeTypeIndex>> { + static int32_t v() { return T::ContainerType::RuntimeTypeIndex(); } +}; + +// None +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone; + + static TVM_FFI_INLINE void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFINone; + // invariant: the pointer field also equals nullptr + // this will simplify same_as comparisons and hash + result->v_int64 = 0; + } + + static TVM_FFI_INLINE void MoveToAny(std::nullptr_t, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFINone; + // invariant: the pointer field also equals nullptr + // this will simplify same_as comparisons and hash + result->v_int64 = 0; + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFINone; + } + + static TVM_FFI_INLINE std::nullptr_t CopyFromAnyStorageAfterCheck(const TVMFFIAny*) { + return nullptr; + } + + static TVM_FFI_INLINE std::nullptr_t MoveFromAnyStorageAfterCheck(TVMFFIAny*) { return nullptr; } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return nullptr; + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFINone; } +}; + +/** + * \brief A type that forbids implicit conversion from int to bool + * + * This type is used to prevent implicit conversion from int to bool. + */ +class StrictBool { + public: + StrictBool(bool value) : value_(value) {} // NOLINT(*) + operator bool() const { return value_; } + + private: + bool value_; +}; + +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; + + static TVM_FFI_INLINE void CopyToAnyView(const StrictBool& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIBool; + result->v_int64 = static_cast(src); + } + + static TVM_FFI_INLINE void MoveToAny(StrictBool src, TVMFFIAny* result) { + CopyToAnyView(src, result); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIBool; + } + + static TVM_FFI_INLINE StrictBool CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return static_cast(src->v_int64); + } + + static TVM_FFI_INLINE StrictBool MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIBool) { + return StrictBool(static_cast(src->v_int64)); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } +}; + +// Bool type, allow implicit casting from int +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; + + static TVM_FFI_INLINE void CopyToAnyView(const bool& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIBool; + result->v_int64 = static_cast(src); + } + + static TVM_FFI_INLINE void MoveToAny(bool src, TVMFFIAny* result) { CopyToAnyView(src, result); } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIBool; + } + + static TVM_FFI_INLINE bool CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return static_cast(src->v_int64); + } + + static TVM_FFI_INLINE bool MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return static_cast(src->v_int64); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } +}; + +// Integer POD values +template +struct TypeTraits>> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; + + static TVM_FFI_INLINE void CopyToAnyView(const Int& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIInt; + result->v_int64 = static_cast(src); + } + + static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { CopyToAnyView(src, result); } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + // NOTE: CheckAnyStorage is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIInt; + } + + static TVM_FFI_INLINE Int CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return static_cast(src->v_int64); + } + + static TVM_FFI_INLINE Int MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return Int(src->v_int64); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } +}; + +// Float POD values +template +struct TypeTraits>> + : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat; + + static TVM_FFI_INLINE void CopyToAnyView(const Float& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIFloat; + result->v_float64 = static_cast(src); + } + + static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) { CopyToAnyView(src, result); } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + // NOTE: CheckAnyStorage is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIFloat; + } + + static TVM_FFI_INLINE Float CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return static_cast(src->v_float64); + } + + static TVM_FFI_INLINE Float MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIFloat) { + return Float(src->v_float64); + } else if (src->type_index == TypeIndex::kTVMFFIInt || + src->type_index == TypeIndex::kTVMFFIBool) { + return Float(src->v_int64); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIFloat; } +}; + +// void* +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIOpaquePtr; + + static TVM_FFI_INLINE void CopyToAnyView(void* src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIOpaquePtr; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = src; + } + + static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) { CopyToAnyView(src, result); } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + // NOTE: CheckAnyStorage is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIOpaquePtr; + } + + static TVM_FFI_INLINE void* CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return src->v_ptr; + } + + static TVM_FFI_INLINE void* MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { + return static_cast(src->v_ptr); + } + if (src->type_index == TypeIndex::kTVMFFINone) { + return static_cast(nullptr); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIOpaquePtr; } +}; + +// Device +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice; + + static TVM_FFI_INLINE void CopyToAnyView(const DLDevice& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDevice; + result->v_device = src; + } + + static TVM_FFI_INLINE void MoveToAny(DLDevice src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDevice; + result->v_device = src; + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIDevice; + } + + static TVM_FFI_INLINE DLDevice CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return src->v_device; + } + + static TVM_FFI_INLINE DLDevice MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIDevice) { + return src->v_device; + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIDevice; } +}; + +// DLTensor*, requirement: not nullable, do not retain ownership +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr bool storage_enabled = false; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; + + static TVM_FFI_INLINE void CopyToAnyView(DLTensor* src, TVMFFIAny* result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIDLTensorPtr; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = src; + } + + static TVM_FFI_INLINE void MoveToAny(DLTensor*, TVMFFIAny*) { + TVM_FFI_THROW(RuntimeError) + << "DLTensor* cannot be held in Any as it does not retain ownership, use NDArray instead"; + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { + return static_cast(src->v_ptr); + } else if (src->type_index == TypeIndex::kTVMFFINDArray) { + // Conversion from NDArray pointer to DLTensor + // based on the assumption that NDArray always follows the TVMFFIObject header + static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 8 bytes"); + return reinterpret_cast(reinterpret_cast(src->v_obj) + + sizeof(TVMFFIObject)); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "DLTensor*"; } +}; + +// Traits for ObjectRef, None to ObjectRef will always fail. +// use std::optional instead for nullable references. +template +struct ObjectRefTypeTraitsBase : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; + using ContainerType = typename TObjRef::ContainerType; + + static TVM_FFI_INLINE void CopyToAnyView(const TObjRef& src, TVMFFIAny* result) { + if constexpr (TObjRef::_type_is_nullable) { + if (!src.defined()) { + TypeTraits::CopyToAnyView(nullptr, result); + return; + } + } + TVMFFIObject* obj_ptr = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src); + result->type_index = obj_ptr->type_index; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE void MoveToAny(TObjRef src, TVMFFIAny* result) { + if constexpr (TObjRef::_type_is_nullable) { + if (!src.defined()) { + TypeTraits::CopyToAnyView(nullptr, result); + return; + } + } + TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src)); + result->type_index = obj_ptr->type_index; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) return true; + } + return (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && + details::IsObjectInstance(src->type_index)); + } + + static TVM_FFI_INLINE TObjRef CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return TObjRef(ObjectPtr(nullptr)); + } + } + return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + } + + static TVM_FFI_INLINE TObjRef MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return TObjRef(ObjectPtr(nullptr)); + } + } + // move out the object pointer + ObjectPtr obj_ptr = details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); + // reset the src to nullptr + TypeTraits::MoveToAny(nullptr, src); + return TObjRef(std::move(obj_ptr)); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return TObjRef(ObjectPtr(nullptr)); + } + } + if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + if (details::IsObjectInstance(src->type_index)) { + return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + } + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return ContainerType::_type_key; } +}; + +template +struct TypeTraits && + use_default_type_traits_v>> + : public ObjectRefTypeTraitsBase {}; + +/*! + * \brief Helper class that convert to T only via the FallbackTypes + * + * The conversion will go through the FallbackTypes in the order + * specified in the template parameter. + * \tparam T The type of the target value. + * \tparam FallbackTypes The type of the fallback value. + * \note TypeTraits must be derived from this class and define + * ConvertFallbackValue(FallbackType)->T for each FallbackType + */ +template +struct FallbackOnlyTraitsBase : public TypeTraitsBase { + // disable container for FallbackOnlyTraitsBase + static constexpr bool storage_enabled = false; + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + return TryFallbackTypes(src); + } + + template + static TVM_FFI_INLINE std::optional TryFallbackTypes(const TVMFFIAny* src) { + static_assert(!std::is_same_v, + "Using bool as FallbackType can cause bug because int will be detected as bool, " + "use tvm::ffi::StrictBool instead"); + if (auto opt_fallback = TypeTraits::TryConvertFromAnyView(src)) { + return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryFallbackTypes(src); + } + return std::nullopt; + } +}; + +/*! + * \brief Helper class to define ObjectRef that can be auto-converted from a + * fallback type, the Traits must be derived from it + * and define a static methods named ConvertFallbackValue for each + * FallbackType + * + * The conversion will go through the FallbackTypes in the order + * specified in the template parameter. + * \tparam ObjectRefType The type of the ObjectRef. + * \tparam FallbackTypes The type of the fallback value. + */ +template +struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase { + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (auto opt_obj = ObjectRefTypeTraitsBase::TryConvertFromAnyView(src)) { + return opt_obj.value(); + } + // apply fallback types in TryConvertFromAnyView + return TryFallbackTypes(src); + } + + template + static TVM_FFI_INLINE std::optional TryFallbackTypes(const TVMFFIAny* src) { + static_assert(!std::is_same_v, + "Using bool as FallbackType can cause bug because int will be detected as bool, " + "use tvm::ffi::StrictBool instead"); + if (auto opt_fallback = TypeTraits::TryConvertFromAnyView(src)) { + return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryFallbackTypes(src); + } + return std::nullopt; + } +}; + +// Traits for weak pointer of object +// NOTE: we require the weak pointer cast from +template +struct TypeTraits>> + : public TypeTraitsBase { + static TVM_FFI_INLINE void CopyToAnyView(const TObject* src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); + result->type_index = obj_ptr->type_index; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE void MoveToAny(const TObject* src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); + result->type_index = obj_ptr->type_index; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + // needs to increase ref because original weak ptr do not own the code + details::ObjectUnsafe::IncRefObjectHandle(result->v_obj); + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && + details::IsObjectInstance(src->type_index); + } + + static TVM_FFI_INLINE const TObject* CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); + } + + static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + if (CheckAnyStorage(src)) return CopyFromAnyStorageAfterCheck(src); + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return TObject::_type_key; } +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + static TVM_FFI_INLINE void CopyToAnyView(const Optional& src, TVMFFIAny* result) { + if (src.has_value()) { + TypeTraits::CopyToAnyView(*src, result); + } else { + TypeTraits::CopyToAnyView(nullptr, result); + } + } + + static TVM_FFI_INLINE void MoveToAny(Optional src, TVMFFIAny* result) { + if (src.has_value()) { + TypeTraits::MoveToAny(*std::move(src), result); + } else { + TypeTraits::CopyToAnyView(nullptr, result); + } + } + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) return true; + return TypeTraits::CheckAnyStorage(src); + } + + static TVM_FFI_INLINE Optional CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return Optional(std::nullopt); + } + return TypeTraits::CopyFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE Optional MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return Optional(std::nullopt); + } + return TypeTraits::MoveFromAnyStorageAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) return Optional(std::nullopt); + if (std::optional opt = TypeTraits::TryConvertFromAnyView(src)) { + return Optional(*std::move(opt)); + } else { + // important to be explicit here + // because nullopt can convert to std::optional(nullopt) which indicate success + // return std::optional>(std::nullopt) to indicate failure + return std::optional>(std::nullopt); + } + } + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + return TypeTraits::GetMismatchTypeInfo(src); + } + + static TVM_FFI_INLINE std::string TypeStr() { + return "Optional<" + TypeTraits::TypeStr() + ">"; + } +}; +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/ffi/scripts/run_tests.sh b/ffi/scripts/run_tests.sh new file mode 100755 index 000000000000..8fc9eb95d005 --- /dev/null +++ b/ffi/scripts/run_tests.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +set -euxo pipefail + +BUILD_TYPE=RelWithDebugInfo + +rm -rf build/CMakeFiles build/CMakeCache.txt +cmake -G Ninja -S . -B build -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache +cmake --build build --parallel 16 --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests +GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir build --output-on-failure diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc new file mode 100644 index 000000000000..563505b74c5f --- /dev/null +++ b/ffi/src/ffi/container.cc @@ -0,0 +1,95 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/ffi_api.cc + * \brief Extra ffi apis for frontend to access containers. + */ +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +TVM_FFI_REGISTER_GLOBAL("ffi.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) { + *ret = Array(args.data(), args.data() + args.size()); +}); + +TVM_FFI_REGISTER_GLOBAL("ffi.ArrayGetItem") + .set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }); + +TVM_FFI_REGISTER_GLOBAL("ffi.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t { + return static_cast(n->size()); +}); +// Map +TVM_FFI_REGISTER_GLOBAL("ffi.Map").set_body_packed([](ffi::PackedArgs args, Any* ret) { + TVM_FFI_ICHECK_EQ(args.size() % 2, 0); + Map data; + for (int i = 0; i < args.size(); i += 2) { + data.Set(args[i], args[i + 1]); + } + *ret = data; +}); + +TVM_FFI_REGISTER_GLOBAL("ffi.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t { + return static_cast(n->size()); +}); + +TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem") + .set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }); + +TVM_FFI_REGISTER_GLOBAL("ffi.MapCount") + .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }); + +// Favor struct outside function scope as MSVC may have bug for in fn scope struct. +class MapForwardIterFunctor { + public: + MapForwardIterFunctor(ffi::MapObj::iterator iter, ffi::MapObj::iterator end) + : iter_(iter), end_(end) {} + // 0 get current key + // 1 get current value + // 2 move to next: return true if success, false if end + Any operator()(int command) const { + if (command == 0) { + return (*iter_).first; + } else if (command == 1) { + return (*iter_).second; + } else { + ++iter_; + if (iter_ == end_) { + return false; + } + return true; + } + } + + private: + mutable ffi::MapObj::iterator iter_; + ffi::MapObj::iterator end_; +}; + +TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor") + .set_body_typed([](const ffi::MapObj* n) -> ffi::Function { + return ffi::Function::FromUnpacked(MapForwardIterFunctor(n->begin(), n->end())); + }); + +} // namespace ffi +} // namespace tvm diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc new file mode 100644 index 000000000000..7661ab4b97b1 --- /dev/null +++ b/ffi/src/ffi/dtype.cc @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include + +namespace tvm { +namespace ffi { +namespace details { +/*! + * \brief Get the custom type name for a given type code. + */ +inline String DLDataTypeCodeGetCustomTypeName(DLDataTypeCode type_code) { + static Function fget_custom_type_name = Function::GetGlobalRequired("dtype.get_custom_type_name"); + return fget_custom_type_name(static_cast(type_code)).cast(); +} + +/*! + * \brief Get the custom type name for a given type code. + * \param str The string to parse. + * \param scan The scan pointer. + * \return The custom type name. + */ +inline int ParseCustomDataTypeCode(const std::string_view& str, const char** scan) { + TVM_FFI_ICHECK(str.substr(0, 6) == "custom") << "Not a valid custom datatype string"; + auto tmp = str.data(); + TVM_FFI_ICHECK(str.data() == tmp); + *scan = str.data() + 6; + TVM_FFI_ICHECK(str.data() == tmp); + if (**scan != '[') + TVM_FFI_THROW(ValueError) << "expected opening brace after 'custom' type in" << str; + TVM_FFI_ICHECK(str.data() == tmp); + *scan += 1; + TVM_FFI_ICHECK(str.data() == tmp); + size_t custom_name_len = 0; + TVM_FFI_ICHECK(str.data() == tmp); + while (*scan + custom_name_len <= str.data() + str.length() && + *(*scan + custom_name_len) != ']') { + ++custom_name_len; + } + TVM_FFI_ICHECK(str.data() == tmp); + if (*(*scan + custom_name_len) != ']') { + TVM_FFI_THROW(ValueError) << "expected closing brace after 'custom' type in" << str; + } + TVM_FFI_ICHECK(str.data() == tmp); + *scan += custom_name_len + 1; + TVM_FFI_ICHECK(str.data() == tmp); + auto type_name = str.substr(7, custom_name_len); + TVM_FFI_ICHECK(str.data() == tmp); + static Function fget_custom_type_code = Function::GetGlobalRequired("dtype.get_custom_type_code"); + return fget_custom_type_code(std::string(type_name)).cast(); +} + +/* + * \brief Convert a DLDataTypeCode to a string. + * \param os The output stream. + * \param type_code The DLDataTypeCode to convert. + */ +inline void PrintDLDataTypeCodeAsStr(std::ostream& os, DLDataTypeCode type_code) { // NOLINT(*) + switch (static_cast(type_code)) { + case kDLInt: { + os << "int"; + break; + } + case kDLUInt: { + os << "uint"; + break; + } + case kDLFloat: { + os << "float"; + break; + } + case kDLOpaqueHandle: { + os << "handle"; + break; + } + case kDLBfloat: { + os << "bfloat"; + break; + } + case kDLFloat8_e3m4: { + os << "float8_e3m4"; + break; + } + case kDLFloat8_e4m3: { + os << "float8_e4m3"; + break; + } + case kDLFloat8_e4m3b11fnuz: { + os << "float8_e4m3b11fnuz"; + break; + } + case kDLFloat8_e4m3fn: { + os << "float8_e4m3fn"; + break; + } + case kDLFloat8_e4m3fnuz: { + os << "float8_e4m3fnuz"; + break; + } + case kDLFloat8_e5m2: { + os << "float8_e5m2"; + break; + } + case kDLFloat8_e5m2fnuz: { + os << "float8_e5m2fnuz"; + break; + } + case kDLFloat8_e8m0fnu: { + os << "float8_e8m0fnu"; + break; + } + case kDLFloat6_e2m3fn: { + os << "float6_e2m3fn"; + break; + } + case kDLFloat6_e3m2fn: { + os << "float6_e3m2fn"; + break; + } + case kDLFloat4_e2m1fn: { + os << "float4_e2m1fn"; + break; + } + default: { + if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { + os << "custom[" << details::DLDataTypeCodeGetCustomTypeName(type_code) << "]"; + } else { + TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" + << static_cast(type_code); + } + TVM_FFI_UNREACHABLE(); + } + } +} +} // namespace details + +/*! + * \brief Printer function for DLDataType. + * \param os The output stream. + * \param dtype The DLDataType to print. + * \return The output stream. + */ +inline std::string DLDataTypeToString_(DLDataType dtype) { // NOLINT(*) + if (dtype.bits == 1 && dtype.lanes == 1 && dtype.code == kDLUInt) { + return "bool"; + } + // specially handle void + if (dtype.code == kDLOpaqueHandle && dtype.lanes == 0 && dtype.bits == 0) { + return ""; + } + + std::ostringstream os; + if (dtype.code >= kDLExtCustomBegin) { + os << "custom[" + << details::DLDataTypeCodeGetCustomTypeName(static_cast(dtype.code)) << "]"; + } else { + os << details::DLDataTypeCodeAsCStr(static_cast(dtype.code)); + } + if (dtype.code == kDLOpaqueHandle) return os.str(); + int16_t lanes = static_cast(dtype.lanes); + if (dtype.code < kDLFloat8_e3m4) { + os << static_cast(dtype.bits); + } + if (lanes > 1) { + os << 'x' << lanes; + } else if (lanes < -1) { + os << "xvscalex" << -lanes; + } + return os.str(); +} + +/*! + * \brief Parse a string to a DLDataType. + * \param str The string to convert. + * \return The corresponding DLDataType. + */ +inline DLDataType StringViewToDLDataType_(std::string_view str) { + DLDataType dtype; + // handle void type + if (str.length() == 0 || str == "void") { + dtype.code = kDLOpaqueHandle; + dtype.bits = 0; + dtype.lanes = 0; + return dtype; + } + // set the default values; + dtype.bits = 32; + dtype.lanes = 1; + const char* scan; + + auto parse_float = [&](const std::string_view& str, int offset, int code, int bits) { + dtype.code = static_cast(code); + dtype.bits = static_cast(bits); + scan = str.data() + offset; + char* endpt = nullptr; + if (*scan == 'x') { + dtype.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); + scan = endpt; + } + if (scan != str.data() + str.length()) { + TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; + } + return dtype; + }; + + if (str.compare(0, 3, "int") == 0) { + dtype.code = kDLInt; + scan = str.data() + 3; + } else if (str.compare(0, 4, "uint") == 0) { + dtype.code = kDLUInt; + scan = str.data() + 4; + } else if (str.compare(0, 5, "float") == 0) { + if (str.compare(5, 2, "8_") == 0) { + if (str.compare(7, 4, "e3m4") == 0) { + return parse_float(str, 11, kDLFloat8_e3m4, 8); + } else if (str.compare(7, 4, "e4m3") == 0) { + if (str.compare(11, 7, "b11fnuz") == 0) { + return parse_float(str, 18, kDLFloat8_e4m3b11fnuz, 8); + } else if (str.compare(11, 2, "fn") == 0) { + if (str.compare(13, 2, "uz") == 0) { + return parse_float(str, 15, kDLFloat8_e4m3fnuz, 8); + } else { + return parse_float(str, 13, kDLFloat8_e4m3fn, 8); + } + } else { + return parse_float(str, 11, kDLFloat8_e4m3, 8); + } + } else if (str.compare(7, 8, "e5m2fnuz") == 0) { + return parse_float(str, 15, kDLFloat8_e5m2fnuz, 8); + } else if (str.compare(7, 4, "e5m2") == 0) { + return parse_float(str, 11, kDLFloat8_e5m2, 8); + } else if (str.compare(7, 7, "e8m0fnu") == 0) { + return parse_float(str, 14, kDLFloat8_e8m0fnu, 8); + } else { + TVM_FFI_THROW(ValueError) << "unknown float8 type `" << str << '`'; + TVM_FFI_UNREACHABLE(); + } + } else if (str.compare(5, 2, "6_") == 0) { + if (str.compare(7, 6, "e2m3fn") == 0) { + return parse_float(str, 13, kDLFloat6_e2m3fn, 6); + } else if (str.compare(7, 6, "e3m2fn") == 0) { + return parse_float(str, 13, kDLFloat6_e3m2fn, 6); + } else { + TVM_FFI_THROW(ValueError) << "unknown float6 type `" << str << '`'; + TVM_FFI_UNREACHABLE(); + } + } else if (str.compare(5, 2, "4_") == 0) { + // kFloat4_e2m1fn + if (str.compare(7, 6, "e2m1fn") == 0) { + return parse_float(str, 13, kDLFloat4_e2m1fn, 4); + } else { + TVM_FFI_THROW(ValueError) << "unknown float4 type `" << str << '`'; + TVM_FFI_UNREACHABLE(); + } + } else { + dtype.code = kDLFloat; + scan = str.data() + 5; + } + } else if (str.compare(0, 6, "handle") == 0) { + dtype.code = kDLOpaqueHandle; + dtype.bits = 64; // handle uses 64 bit by default. + scan = str.data() + 6; + } else if (str == "bool") { + dtype.code = kDLUInt; + dtype.bits = 1; + dtype.lanes = 1; + return dtype; + } else if (str.compare(0, 6, "bfloat") == 0) { + dtype.code = kDLBfloat; + dtype.bits = 16; + scan = str.data() + 6; + } else if (str.compare(0, 6, "custom") == 0) { + dtype.code = static_cast(details::ParseCustomDataTypeCode(str, &scan)); + } else { + scan = str.data(); + TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; + } + char* xdelim; // emulate sscanf("%ux%u", bits, lanes) + uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); + if (bits != 0) dtype.bits = bits; + int scalable_multiplier = 1; + if (strncmp(xdelim, "xvscale", 7) == 0) { + scalable_multiplier = -1; + xdelim += 7; + } + char* endpt = xdelim; + if (*xdelim == 'x') { + dtype.lanes = static_cast(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10)); + } + if (endpt != str.data() + str.length()) { + TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; + } + return dtype; +} + +} // namespace ffi +} // namespace tvm + +int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::StringViewToDLDataType_(std::string_view(str->data, str->size)); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(dtype)); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str)); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/src/ffi/error.cc b/ffi/src/ffi/error.cc new file mode 100644 index 000000000000..c8c77e510d60 --- /dev/null +++ b/ffi/src/ffi/error.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/error.cc + * \brief Error handling implementation + */ +#include +#include + +namespace tvm { +namespace ffi { + +class SafeCallContext { + public: + void SetRaised(TVMFFIObjectHandle error) { + last_error_ = + details::ObjectUnsafe::ObjectPtrFromUnowned(static_cast(error)); + } + + void SetRaisedByCstr(const char* kind, const char* message, const TVMFFIByteArray* traceback) { + Error error(kind, message, traceback); + last_error_ = details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(error)); + } + + void MoveFromRaised(TVMFFIObjectHandle* result) { + result[0] = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(last_error_)); + } + + static SafeCallContext* ThreadLocal() { + static thread_local SafeCallContext ctx; + return &ctx; + } + + private: + ObjectPtr last_error_; +}; + +} // namespace ffi +} // namespace tvm + +void TVMFFIErrorSetRaisedByCStr(const char* kind, const char* message) { + // NOTE: run traceback here to simplify the depth of tracekback + tvm::ffi::SafeCallContext::ThreadLocal()->SetRaisedByCstr(kind, message, TVM_FFI_TRACEBACK_HERE); +} + +void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) { + tvm::ffi::SafeCallContext::ThreadLocal()->SetRaised(error); +} + +void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) { + tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromRaised(result); +} + +TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message, + const TVMFFIByteArray* traceback) { + TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); + tvm::ffi::Error error(std::string(kind->data, kind->size), + std::string(message->data, message->size), + std::string(traceback->data, traceback->size)); + TVMFFIObjectHandle out = + tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(error)); + return out; + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorCreate); +} diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc new file mode 100644 index 000000000000..ed10ea59c9bc --- /dev/null +++ b/ffi/src/ffi/function.cc @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/function.cc + * \brief Function call registry and safecall context + */ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Global function table. + * + + * \note We do not use mutex to guard updating of GlobalFunctionTable + * + * The assumption is that updating of GlobalFunctionTable will be done + * in the main thread during initialization or loading, or + * explicitly locked from the caller. + * + * Then the followup code will leverage the information + */ +class GlobalFunctionTable { + public: + void Update(const String& name, Function func, bool can_override) { + if (table_.count(name)) { + if (!can_override) { + TVM_FFI_THROW(RuntimeError) << "Global Function `" << name << "` is already registered"; + } + } + table_[name] = new Function(func); + } + + bool Remove(const String& name) { + auto it = table_.find(name); + if (it == table_.end()) return false; + table_.erase(it); + return true; + } + + const Function* Get(const String& name) { + auto it = table_.find(name); + if (it == table_.end()) return nullptr; + return it->second; + } + + Array ListNames() const { + Array names; + names.reserve(table_.size()); + for (const auto& kv : table_) { + names.push_back(kv.first); + } + return names; + } + + static GlobalFunctionTable* Global() { + // We deliberately create a new instance via raw new + // This is because GlobalFunctionTable can contain callbacks into + // the host language (Python) and the resource can become invalid + // indeterministic order of destruction and forking. + // The resources will only be recycled during program exit. + static GlobalFunctionTable* inst = new GlobalFunctionTable(); + return inst; + } + + private: + // deliberately track function pointer without recycling + // to avoid + std::unordered_map table_; +}; + +/*! + * \brief Execution environment specific API registry. + * + * This registry stores C API function pointers about + * execution environment(e.g. python) specific API function that + * we need for specific low-level handling(e.g. signal checking). + * + * We only stores the C API function when absolutely necessary (e.g. when signal handler + * cannot trap back into python). Always consider use the Function FFI when possible + * in other cases. + */ +class EnvCAPIRegistry { + public: + /*! + * \brief Callback to check if signals have been sent to the process and + * if so invoke the registered signal handler in the frontend environment. + * + * When running FFI in another language (Python), the signal handler + * may not be immediately executed, but instead the signal is marked + * in the interpreter state (to ensure non-blocking of the signal handler). + * + * \return 0 if no error happens, -1 if error happens. + */ + typedef int (*F_PyErr_CheckSignals)(); + + /*! \brief Callback to increment/decrement the python ref count */ + typedef void (*F_Py_IncDefRef)(void*); + + /*! + * \brief PyErr_CheckSignal function + */ + F_PyErr_CheckSignals pyerr_check_signals = nullptr; + + /*! + \brief PyGILState_Ensure function + */ + void* (*py_gil_state_ensure)() = nullptr; + + /*! + \brief PyGILState_Release function + */ + void (*py_gil_state_release)(void*) = nullptr; + + static EnvCAPIRegistry* Global() { + static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); + return inst; + } + + // register environment(e.g. python) specific api functions + void Register(const std::string& symbol_name, void* fptr) { + if (symbol_name == "PyErr_CheckSignals") { + Update(symbol_name, &pyerr_check_signals, fptr); + } else if (symbol_name == "PyGILState_Ensure") { + Update(symbol_name, &py_gil_state_ensure, fptr); + } else if (symbol_name == "PyGILState_Release") { + Update(symbol_name, &py_gil_state_release, fptr); + } else { + TVM_FFI_THROW(ValueError) << "Unknown env API " + symbol_name; + } + } + + // implementation of tvm::runtime::EnvCheckSignals + int EnvCheckSignals() { + // check python signal to see if there are exception raised + if (pyerr_check_signals != nullptr) { + // The C++ env comes without gil, so we need to grab gil here + WithGIL context(this); + if ((*pyerr_check_signals)() != 0) { + // The error will let FFI know that the frontend environment + // already set an error. + return -1; + } + } + return 0; + } + + private: + // update the internal API table + template + void Update(const String& symbol_name, FType* target, void* ptr) { + FType ptr_casted = reinterpret_cast(ptr); + target[0] = ptr_casted; + } + + struct WithGIL { + explicit WithGIL(EnvCAPIRegistry* self) : self(self) { + TVM_FFI_ICHECK(self->py_gil_state_ensure); + TVM_FFI_ICHECK(self->py_gil_state_release); + gil_state = self->py_gil_state_ensure(); + } + ~WithGIL() { + if (self && gil_state) { + self->py_gil_state_release(gil_state); + } + } + WithGIL(const WithGIL&) = delete; + WithGIL(WithGIL&&) = delete; + WithGIL& operator=(const WithGIL&) = delete; + WithGIL& operator=(WithGIL&&) = delete; + + EnvCAPIRegistry* self = nullptr; + void* gil_state = nullptr; + }; +}; +} // namespace ffi +} // namespace tvm + +int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), + TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::Function func = tvm::ffi::Function::FromExternC(self, safe_call, deleter); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::Any result(*reinterpret_cast(any_view)); + *out = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(result)); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) { + using namespace tvm::ffi; + TVM_FFI_SAFE_CALL_BEGIN(); + String name_str(name->data, name->size); + GlobalFunctionTable::Global()->Update(name_str, GetRef(static_cast(f)), + override != 0); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out) { + using namespace tvm::ffi; + TVM_FFI_SAFE_CALL_BEGIN(); + String name_str(name->data, name->size); + const Function* fp = GlobalFunctionTable::Global()->Get(name_str); + if (fp != nullptr) { + tvm::ffi::Function func(*fp); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); + } else { + *out = nullptr; + } + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result) { + using namespace tvm::ffi; + // NOTE: this is a tail call + return reinterpret_cast(func)->safe_call(func, args, num_args, result); +} + +int TVMFFIEnvCheckSignals() { return tvm::ffi::EnvCAPIRegistry::Global()->EnvCheckSignals(); } + +/*! + * \brief Register a symbol into the from the surrounding env. + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) { + TVM_FFI_SAFE_CALL_BEGIN(); + std::string s_name(name->data, name->size); + tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol); + TVM_FFI_SAFE_CALL_END(); +} + +TVM_FFI_REGISTER_GLOBAL("ffi.FunctionRemoveGlobal") + .set_body_typed([](const tvm::ffi::String& name) -> bool { + return tvm::ffi::GlobalFunctionTable::Global()->Remove(name); + }); + +TVM_FFI_REGISTER_GLOBAL("ffi.FunctionListGlobalNamesFunctor").set_body_typed([]() { + // NOTE: we return functor instead of array + // so list global function names do not need to depend on array + // this is because list global function names usually is a core api that happens + // before array ffi functions are available. + tvm::ffi::Array names = tvm::ffi::GlobalFunctionTable::Global()->ListNames(); + auto return_functor = [names](int64_t i) -> tvm::ffi::Any { + if (i < 0) { + return names.size(); + } else { + return names[i]; + } + }; + return tvm::ffi::Function::FromUnpacked(return_functor); +}); + +TVM_FFI_REGISTER_GLOBAL("ffi.String").set_body_typed([](tvm::ffi::String val) -> tvm::ffi::String { + return val; +}); + +TVM_FFI_REGISTER_GLOBAL("ffi.Bytes").set_body_typed([](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { + return val; +}); diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc new file mode 100644 index 000000000000..d4c1470566bf --- /dev/null +++ b/ffi/src/ffi/ndarray.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/ndarray.cc + * \brief NDArray C API implementation + */ +#include +#include +#include + +namespace tvm { +namespace ffi { + +// Shape +TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, Any* ret) { + int64_t* mutable_data; + ObjectPtr shape = details::MakeEmptyShape(args.size(), &mutable_data); + for (int i = 0; i < args.size(); ++i) { + if (auto opt_int = args[i].as()) { + mutable_data[i] = *opt_int; + } else { + TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; + } + } + *ret = Shape(shape); +}); +} // namespace ffi +} // namespace tvm + +int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t min_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::NDArray nd = + tvm::ffi::NDArray::FromDLPack(from, static_cast(min_alignment), require_contiguous); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd)); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, int32_t min_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::NDArray nd = tvm::ffi::NDArray::FromDLPackVersioned( + from, static_cast(min_alignment), require_contiguous); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd)); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned( + static_cast(from)) + ->ToDLPack(); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, DLManagedTensorVersioned** out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned( + static_cast(from)) + ->ToDLPackVersioned(); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc new file mode 100644 index 000000000000..63ec68790e57 --- /dev/null +++ b/ffi/src/ffi/object.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/object.cc + * \brief Registry to record dynamic types + */ +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Global registry that manages + * + * \note We do not use mutex to guard updating of TypeTable + * + * The assumption is that updating of TypeTable will be done + * in the main thread during initialization or loading, or + * explicitly locked from the caller. + * + * Then the followup code will leverage the information + */ +class TypeTable { + public: + /*! \brief Type information */ + struct Entry : public TypeInfo { + /*! \brief stored type key */ + std::string type_key_data; + /*! \brief acenstor information */ + std::vector type_acenstors_data; + /*! \brief type fields informaton */ + std::vector type_fields_data; + // NOTE: the indices in [index, index + num_reserved_slots) are + // reserved for the child-class of this type. + /*! \brief Total number of slots reserved for the type and its children. */ + int32_t num_slots; + /*! \brief number of allocated child slots. */ + int32_t allocated_slots; + /*! \brief Whether child can overflow. */ + bool child_slots_can_overflow{true}; + + Entry(int32_t type_index, int32_t type_depth, std::string type_key, int32_t num_slots, + bool child_slots_can_overflow, const Entry* parent) { + // setup fields in the class + this->type_key_data = std::move(type_key); + this->num_slots = num_slots; + this->allocated_slots = 1; + this->child_slots_can_overflow = child_slots_can_overflow; + // set up type acenstors information + if (type_depth != 0) { + TVM_FFI_ICHECK_NOTNULL(parent); + TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1); + type_acenstors_data.resize(type_depth); + // copy over parent's type information + for (int32_t i = 0; i < parent->type_depth; ++i) { + type_acenstors_data[i] = parent->type_acenstors[i]; + } + // set last type information to be parent + type_acenstors_data[parent->type_depth] = parent->type_index; + } + // initialize type info: no change to type_key and type_acenstors fields + // after this line + this->type_index = type_index; + this->type_depth = type_depth; + this->type_key = TVMFFIByteArray{this->type_key_data.data(), this->type_key_data.length()}; + this->type_key_hash = std::hash()(this->type_key_data); + this->type_acenstors = type_acenstors_data.data(); + // initialize the reflection information + this->num_fields = 0; + this->num_methods = 0; + this->fields = nullptr; + this->methods = nullptr; + } + }; + + int32_t GetOrAllocTypeIndex(std::string type_key, int32_t static_type_index, int32_t type_depth, + int32_t num_child_slots, bool child_slots_can_overflow, + int32_t parent_type_index) { + auto it = type_key2index_.find(type_key); + if (it != type_key2index_.end()) { + return type_table_[it->second]->type_index; + } + + // get parent's entry + Entry* parent = [&]() -> Entry* { + if (parent_type_index < 0) return nullptr; + // try to allocate from parent's type table. + TVM_FFI_ICHECK_LT(parent_type_index, type_table_.size()) + << " type_key=" << type_key << ", static_index=" << static_type_index; + return type_table_[parent_type_index].get(); + }(); + + // get allocated index + int32_t allocated_tindex = [&]() { + // Step 0: static allocation + if (static_type_index >= 0) { + TVM_FFI_ICHECK_LT(static_type_index, type_table_.size()); + TVM_FFI_ICHECK(type_table_[static_type_index] == nullptr) + << "Conflicting static index " << static_type_index << " between " + << ToStringView(type_table_[static_type_index]->type_key) << " and " << type_key; + return static_type_index; + } + TVM_FFI_ICHECK_NOTNULL(parent); + int num_slots = num_child_slots + 1; + if (parent->allocated_slots + num_slots <= parent->num_slots) { + // allocate the slot from parent's reserved pool + int32_t allocated_tindex = parent->type_index + parent->allocated_slots; + // update parent's state + parent->allocated_slots += num_slots; + return allocated_tindex; + } + // Step 2: allocate from overflow + TVM_FFI_ICHECK(parent->child_slots_can_overflow) + << "Reach maximum number of sub-classes for " << ToStringView(parent->type_key); + // allocate new entries. + int32_t allocated_tindex = type_counter_; + type_counter_ += num_slots; + TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_); + type_table_.reserve(type_counter_); + // resize type table + while (static_cast(type_table_.size()) < type_counter_) { + type_table_.emplace_back(nullptr); + } + return allocated_tindex; + }(); + + // if parent cannot overflow, then this class cannot. + if (parent != nullptr && !(parent->child_slots_can_overflow)) { + child_slots_can_overflow = false; + } + // total number of slots include the type itself. + + if (parent != nullptr) { + TVM_FFI_ICHECK_GT(allocated_tindex, parent->type_index); + } + + type_table_[allocated_tindex] = + std::make_unique(allocated_tindex, type_depth, type_key, num_child_slots + 1, + child_slots_can_overflow, parent); + // update the key2index mapping. + type_key2index_[type_key] = allocated_tindex; + return allocated_tindex; + } + + int32_t TypeKeyToIndex(const TVMFFIByteArray* type_key) { + std::string type_key_str(type_key->data, type_key->size); + auto it = type_key2index_.find(type_key_str); + TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type `" << type_key_str << "`"; + return it->second; + } + + Entry* GetTypeEntry(int32_t type_index) { + Entry* entry = nullptr; + if (type_index >= 0 && static_cast(type_index) < type_table_.size()) { + entry = type_table_[type_index].get(); + } + TVM_FFI_ICHECK(entry != nullptr) << "Cannot find type info for type_index=" << type_index; + return entry; + } + + void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { + Entry* entry = GetTypeEntry(type_index); + TVMFFIFieldInfo field_data = *info; + field_data.name = this->CopyString(info->name); + entry->type_fields_data.push_back(field_data); + // refresh ptr as the data can change + entry->fields = entry->type_fields_data.data(); + entry->num_fields = static_cast(entry->type_fields_data.size()); + } + + void Dump(int min_children_count) { + std::vector num_children(type_table_.size(), 0); + // expected child slots compute the expected slots + // based on the current child slot setting + std::vector expected_child_slots(type_table_.size(), 0); + // reverse accumulation so we can get total counts in a bottom-up manner. + for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { + const Entry* ptr = it->get(); + if (ptr != nullptr && ptr->type_depth != 0) { + int parent_index = ptr->type_acenstors[ptr->type_depth - 1]; + num_children[parent_index] += num_children[ptr->type_index] + 1; + if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) { + expected_child_slots[ptr->type_index] = ptr->num_slots - 1; + } + expected_child_slots[parent_index] += expected_child_slots[ptr->type_index] + 1; + } + } + + for (const auto& ptr : type_table_) { + if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) { + std::cerr << '[' << ptr->type_index << "]\t" << ToStringView(ptr->type_key); + if (ptr->type_depth != 0) { + int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]; + std::cerr << "\tparent=" << ToStringView(type_table_[parent_index]->type_key); + } else { + std::cerr << "\tparent=root"; + } + std::cerr << "\tnum_child_slots=" << ptr->num_slots - 1 + << "\tnum_children=" << num_children[ptr->type_index] + << "\texpected_child_slots=" << expected_child_slots[ptr->type_index] + << std::endl; + } + } + } + + static TypeTable* Global() { + static TypeTable inst; + return &inst; + } + + private: + TypeTable() { + type_table_.reserve(TypeIndex::kTVMFFIDynObjectBegin); + for (int32_t i = 0; i < TypeIndex::kTVMFFIDynObjectBegin; ++i) { + type_table_.emplace_back(nullptr); + } + // initialize the entry for object + this->GetOrAllocTypeIndex(Object::_type_key, Object::_type_index, Object::_type_depth, + Object::_type_child_slots, Object::_type_child_slots_can_overflow, + -1); + // reserve the static types + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIFloat, TypeIndex::kTVMFFIFloat); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIBool, TypeIndex::kTVMFFIBool); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIRawStr, TypeIndex::kTVMFFIRawStr); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIOpaquePtr, TypeIndex::kTVMFFIOpaquePtr); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDataType, TypeIndex::kTVMFFIDataType); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDevice, TypeIndex::kTVMFFIDevice); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr, TypeIndex::kTVMFFIByteArrayPtr); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, + TypeIndex::kTVMFFIObjectRValueRef); + // no need to reserve for object types as they will be registered + } + + void ReserveBuiltinTypeIndex(const char* type_key, int32_t static_type_index) { + this->GetOrAllocTypeIndex(type_key, static_type_index, 0, 0, false, -1); + } + + TVMFFIByteArray CopyString(TVMFFIByteArray str) { + std::unique_ptr val = std::make_unique(str.data, str.size); + TVMFFIByteArray c_val{val->data(), val->length()}; + string_pool_.emplace_back(std::move(val)); + return c_val; + } + + int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin}; + std::vector> type_table_; + std::unordered_map type_key2index_; + std::vector> string_pool_; +}; +} // namespace ffi +} // namespace tvm + +int TVMFFIObjectFree(TVMFFIObjectHandle handle) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { + TVM_FFI_SAFE_CALL_BEGIN(); + out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info); + TVM_FFI_SAFE_CALL_END(); +} + +int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, int32_t static_type_index, + int32_t type_depth, int32_t num_child_slots, + int32_t child_slots_can_overflow, int32_t parent_type_index) { + TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); + std::string s_type_key = std::string(type_key->data, type_key->size); + return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex( + s_type_key, static_type_index, type_depth, num_child_slots, child_slots_can_overflow, + parent_type_index); + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetOrAllocTypeIndex); +} + +const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { + TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); + return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index); + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo); +} diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/testing.cc new file mode 100644 index 000000000000..050ac28c476e --- /dev/null +++ b/ffi/src/ffi/testing.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +// This file is used for testing the FFI API. +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { + +void TestRaiseError(String kind, String msg) { + throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE); +} + +TVM_FFI_REGISTER_GLOBAL("testing.test_raise_error").set_body_typed(TestRaiseError); + +TVM_FFI_REGISTER_GLOBAL("testing.nop").set_body_packed([](PackedArgs args, Any* ret) { + *ret = args[0]; +}); + +TVM_FFI_REGISTER_GLOBAL("testing.echo").set_body_packed([](PackedArgs args, Any* ret) { + *ret = args[0]; +}); + +void TestApply(Function f, PackedArgs args, Any* ret) { f.CallPacked(args, ret); } + +TVM_FFI_REGISTER_GLOBAL("testing.apply").set_body_packed([](PackedArgs args, Any* ret) { + auto f = args[0].cast(); + TestApply(f, args.Slice(1), ret); +}); + +TVM_FFI_REGISTER_GLOBAL("testing.run_check_signal").set_body_typed([](int nsec) { + for (int i = 0; i < nsec; ++i) { + if (TVMFFIEnvCheckSignals() != 0) { + throw ffi::EnvErrorAlreadySet(); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + std::cout << "Function finished without catching signal" << std::endl; +}); + +TVM_FFI_REGISTER_GLOBAL("testing.object_use_count").set_body_typed([](const Object* obj) { + return obj->use_count(); +}); + +} // namespace ffi +} // namespace tvm diff --git a/ffi/src/ffi/traceback.cc b/ffi/src/ffi/traceback.cc new file mode 100644 index 000000000000..4c7fb25427c0 --- /dev/null +++ b/ffi/src/ffi/traceback.cc @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file traceback.cc + * \brief Traceback implementation on non-windows platforms + * \note We use the term "traceback" to be consistent with python naming convention. + */ +#ifndef _MSC_VER + +#include "./traceback.h" + +#include +#include + +#if TVM_FFI_USE_LIBBACKTRACE + +#include +#include + +#include +#include +#include +#include + +#if TVM_FFI_BACKTRACE_ON_SEGFAULT +#include +#endif + +namespace tvm { +namespace ffi { +namespace { + +void BacktraceCreateErrorCallback(void*, const char* msg, int) { + std::cerr << "Could not initialize backtrace state: " << msg << std::endl; +} + +backtrace_state* BacktraceCreate() { + return backtrace_create_state(nullptr, 1, BacktraceCreateErrorCallback, nullptr); +} + +static backtrace_state* _bt_state = BacktraceCreate(); + +std::string DemangleName(std::string name) { + int status = 0; + size_t length = name.size(); + std::unique_ptr demangled_name = { + abi::__cxa_demangle(name.c_str(), nullptr, &length, &status), &std::free}; + if (demangled_name && status == 0 && length > 0) { + return demangled_name.get(); + } else { + return name; + } +} + +void BacktraceErrorCallback(void*, const char*, int) { + // do nothing +} + +void BacktraceSyminfoCallback(void* data, uintptr_t pc, const char* symname, uintptr_t, uintptr_t) { + auto str = reinterpret_cast(data); + + if (symname != nullptr) { + *str = DemangleName(symname); + } else { + std::ostringstream s; + s << "0x" << std::setfill('0') << std::setw(sizeof(uintptr_t) * 2) << std::hex << pc; + *str = s.str(); + } +} + +int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int lineno, + const char* symbol) { + auto stack_trace = reinterpret_cast(data); + std::string symbol_str = ""; + if (symbol) { + symbol_str = DemangleName(symbol); + } else { + // see if syminfo gives anything + backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback, &symbol_str); + } + symbol = symbol_str.data(); + + if (stack_trace->ExceedTracebackLimit()) { + return 1; + } + if (ShouldStopTraceback(filename, symbol)) { + return 1; + } + if (ShouldExcludeFrame(filename, symbol)) { + return 0; + } + stack_trace->Append(filename, symbol, lineno); + return 0; +} + +std::string Traceback() { + TracebackStorage traceback; + + if (_bt_state == nullptr) { + return ""; + } + // libbacktrace eats memory if run on multiple threads at the same time, so we guard against it + { + static std::mutex m; + std::lock_guard lock(m); + backtrace_full(_bt_state, 0, BacktraceFullCallback, BacktraceErrorCallback, &traceback); + } + return traceback.GetTraceback(); +} + +#if TVM_FFI_BACKTRACE_ON_SEGFAULT +void backtrace_handler(int sig) { + // Technically we shouldn't do any allocation in a signal handler, but + // Backtrace may allocate. What's the worst it could do? We're already + // crashing. + std::cerr << "!!!!!!! TVM FFI encountered a Segfault !!!!!!!\n" << Traceback() << std::endl; + + // Re-raise signal with default handler + struct sigaction act; + std::memset(&act, 0, sizeof(struct sigaction)); + act.sa_flags = SA_RESETHAND; + act.sa_handler = SIG_DFL; + sigaction(sig, &act, nullptr); + raise(sig); +} + +__attribute__((constructor)) void install_signal_handler(void) { + // this may override already installed signal handlers + std::signal(SIGSEGV, backtrace_handler); +} +#endif // TVM_FFI_BACKTRACE_ON_SEGFAULT +} // namespace +} // namespace ffi +} // namespace tvm + +const TVMFFIByteArray* TVMFFITraceback(const char*, int, const char*) { + static thread_local std::string traceback_str; + static thread_local TVMFFIByteArray traceback_array; + traceback_str = ::tvm::ffi::Traceback(); + traceback_array.data = traceback_str.data(); + traceback_array.size = traceback_str.size(); + return &traceback_array; +} +#else +// fallback implementation simply print out the last trace +const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) { + static thread_local std::string traceback_str; + static thread_local TVMFFIByteArray traceback_array; + std::ostringstream traceback_stream; + // python style backtrace + traceback_stream << " File \"" << filename << "\", line " << lineno << ", in " << func << '\n'; + traceback_str = traceback_stream.str(); + traceback_array.data = traceback_str.data(); + traceback_array.size = traceback_str.size(); + return &traceback_array; +} +#endif // TVM_FFI_USE_LIBBACKTRACE +#endif // _MSC_VER diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h new file mode 100644 index 000000000000..0c07361fb503 --- /dev/null +++ b/ffi/src/ffi/traceback.h @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file traceback.h + * \brief Common headers for traceback. + * \note We use the term "traceback" to be consistent with python naming convention. + */ +#ifndef TVM_FFI_TRACEBACK_H_ +#define TVM_FFI_TRACEBACK_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) // std::getenv is unsafe +#endif + +inline int32_t GetTracebackLimit() { + if (const char* env = std::getenv("TVM_TRACEBACK_LIMIT")) { + return std::stoi(env); + } + return 512; +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/*! + * \brief List frame patterns that should be excluded as they contain less information + */ +inline bool ShouldExcludeFrame(const char* filename, const char* symbol) { + if (filename) { + // Stack frames for TVM FFI + if (strstr(filename, "include/tvm/ffi/error.h")) { + return true; + } + if (strstr(filename, "include/tvm/ffi/function_details.h")) { + return true; + } + if (strstr(filename, "include/tvm/ffi/function.h")) { + return true; + } + if (strstr(filename, "include/tvm/ffi/any.h")) { + return true; + } + if (strstr(filename, "include/tvm/runtime/logging.h")) { + return true; + } + if (strstr(filename, "src/ffi/traceback.cc")) { + return true; + } + // C++ stdlib frames + if (strstr(filename, "include/c++/")) { + return true; + } + } + + if (symbol) { + // C++ stdlib frames + if (strstr(symbol, "__libc_")) { + return true; + } + } + if (strncmp(symbol, "TVMFFIErrorSetRaisedByCStr", 26) == 0) { + return true; + } + // libffi.so stack frames. These may also show up as numeric + // addresses with no symbol name. This could be improved in the + // future by using dladdr() to check whether an address is contained + // in libffi.so + if (strstr(symbol, "ffi_call_")) { + return true; + } + return false; +} + +/** + * \brief List frames that should stop the traceback. + * \param filename The filename of the frame. + * \param symbol The symbol name of the frame. + * \return true if the frame should stop the traceback. + * \note We stop traceback at the FFI boundary. + */ +inline bool ShouldStopTraceback(const char* filename, const char* symbol) { + if (symbol != nullptr) { + if (strncmp(symbol, "TVMFFIFunctionCall", 14) == 0) { + return true; + } + // Python interpreter stack frames + // we stop traceback at the Python interpreter stack frames + // since these frame will be handled from by the python side. + if (strncmp(symbol, "_Py", 3) == 0 || strncmp(symbol, "PyObject", 9) == 0) { + return true; + } + } + return false; +} + +/*! + * \brief storage to store traceback + */ +struct TracebackStorage { + std::vector lines; + /*! \brief Maximum size of the traceback. */ + size_t max_frame_size = GetTracebackLimit(); + + void Append(const char* filename, const char* func, int lineno) { + // skip frames with empty filename + if (filename == nullptr) { + if (func != nullptr) { + if (strncmp(func, "0x0", 3) == 0) { + return; + } + filename = ""; + } else { + return; + } + } + std::ostringstream trackeback_stream; + trackeback_stream << " File \"" << filename << "\""; + if (lineno != 0) { + trackeback_stream << ", line " << lineno; + } + trackeback_stream << ", in " << func << '\n'; + lines.push_back(trackeback_stream.str()); + } + + bool ExceedTracebackLimit() const { return lines.size() >= max_frame_size; } + + // get traceback in the order of most recent call last + std::string GetTraceback() const { + std::string traceback; + for (auto it = lines.rbegin(); it != lines.rend(); ++it) { + traceback.insert(traceback.end(), it->begin(), it->end()); + } + return traceback; + } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_TRACEBACK_H_ diff --git a/ffi/src/ffi/traceback_win.cc b/ffi/src/ffi/traceback_win.cc new file mode 100644 index 000000000000..8278de1d77cf --- /dev/null +++ b/ffi/src/ffi/traceback_win.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file traceback_win.cc + * \brief Traceback implementation on windows platform + * \note We use the term "traceback" to be consistent with python naming convention. + */ +#ifdef _MSC_VER + +// clang-format off +#include +#include // NOLINT(*) +// clang-format on + +#include +#include + +#include +#include + +#include "./traceback.h" + +namespace tvm { +namespace ffi { +namespace { + +std::string Traceback() { + TracebackStorage traceback; + HANDLE process = GetCurrentProcess(); + HANDLE thread = GetCurrentThread(); + + SymSetOptions(SYMOPT_LOAD_LINES | SYMOPT_UNDNAME); + SymInitialize(process, NULL, TRUE); + CONTEXT context = {}; + RtlCaptureContext(&context); + + STACKFRAME64 stack = {}; + DWORD machine_type; + +#if defined(_M_X64) + machine_type = IMAGE_FILE_MACHINE_AMD64; + stack.AddrPC.Offset = context.Rip; + stack.AddrFrame.Offset = context.Rbp; + stack.AddrStack.Offset = context.Rsp; +#elif defined(_M_IX86) + machine_type = IMAGE_FILE_MACHINE_I386; + stack.AddrPC.Offset = context.Eip; + stack.AddrFrame.Offset = context.Ebp; + stack.AddrStack.Offset = context.Esp; +#else +#error "Platform not supported!" +#endif + + stack.AddrPC.Mode = AddrModeFlat; + stack.AddrFrame.Mode = AddrModeFlat; + stack.AddrStack.Mode = AddrModeFlat; + + while (!traceback.ExceedTracebackLimit()) { + if (!StackWalk64(machine_type, process, thread, &stack, &context, nullptr, + SymFunctionTableAccess64, SymGetModuleBase64, nullptr)) { + break; + } + + if (stack.AddrPC.Offset == 0) { + break; + } + const char* filename = nullptr; + const char* symbol = ""; + int lineno = 0; + // Get file and line number + IMAGEHLP_LINE64 line_info; + ZeroMemory(&line_info, sizeof(IMAGEHLP_LINE64)); + line_info.SizeOfStruct = sizeof(IMAGEHLP_LINE64); + DWORD displacement32 = 0; + + if (SymGetLineFromAddr64(process, stack.AddrPC.Offset, &displacement32, &line_info)) { + filename = line_info.FileName; + lineno = line_info.LineNumber; + } + // allocate symbol info that aligns to the SYMBOL_INFO + // we use u64 here to be safe + size_t total_symbol_bytes = sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR); + size_t total_u64_words = (total_symbol_bytes + 7) / 8; + static_assert(8 % alignof(SYMBOL_INFO) == 0); + std::vector symbol_buffer(total_u64_words, 0); + PSYMBOL_INFO symbol_info = reinterpret_cast(symbol_buffer.data()); + symbol_info->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol_info->MaxNameLen = MAX_SYM_NAME; + DWORD64 displacement = 0; + if (SymFromAddr(process, stack.AddrPC.Offset, &displacement, symbol_info)) { + symbol = symbol_info->Name; + } + + if (ShouldStopTraceback(filename, symbol)) { + break; + } + if (ShouldExcludeFrame(filename, symbol)) { + continue; + } + traceback.Append(filename, symbol, lineno); + } + SymCleanup(process); + return traceback.GetTraceback(); +} +} // namespace +} // namespace ffi +} // namespace tvm + +const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) { + static thread_local std::string traceback_str; + static thread_local TVMFFIByteArray traceback_array; + traceback_str = ::tvm::ffi::Traceback(); + traceback_array.data = traceback_str.data(); + traceback_array.size = traceback_str.size(); + return &traceback_array; +} +#endif // _MSC_VER diff --git a/ffi/tests/cpp/CMakeLists.txt b/ffi/tests/cpp/CMakeLists.txt new file mode 100644 index 000000000000..429683600bf9 --- /dev/null +++ b/ffi/tests/cpp/CMakeLists.txt @@ -0,0 +1,26 @@ +file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc") +add_executable( + tvm_ffi_tests + EXCLUDE_FROM_ALL + ${_test_sources} +) +set_target_properties( + tvm_ffi_tests PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) +add_cxx_warning(tvm_ffi_tests) +add_sanitizer_address(tvm_ffi_tests) +add_dsymutil(tvm_ffi_tests) +add_msvc_flags(tvm_ffi_tests) +target_link_libraries(tvm_ffi_tests PRIVATE tvm_ffi_shared) +add_googletest(tvm_ffi_tests) + +if (MSVC) + target_link_options(tvm_ffi_tests PRIVATE /DEBUG) +endif() diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc new file mode 100644 index 000000000000..816ae28e0e9c --- /dev/null +++ b/ffi/tests/cpp/test_any.cc @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Any, Int) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view0.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `int`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + AnyView view1 = 1; + EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); + + auto int_v1 = view1.cast(); + EXPECT_EQ(int_v1, 1); + + int64_t v1 = 2; + view0 = v1; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2); +} + +TEST(Any, bool) { + AnyView view0; + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view0.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `bool`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + AnyView view1 = true; + EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); + EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); + + auto int_v1 = view1.cast(); + EXPECT_EQ(int_v1, 1); + + bool v1 = false; + view0 = v1; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); + EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 0); +} + +TEST(Any, nullptrcmp) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + EXPECT_TRUE(view0 == nullptr); + EXPECT_FALSE(view0 != nullptr); + + view0 = 1; + EXPECT_TRUE(view0 != nullptr); + EXPECT_FALSE(view0 == nullptr); + + Any any0 = view0; + EXPECT_TRUE(any0 != nullptr); + EXPECT_FALSE(any0 == nullptr); + + any0 = nullptr; + EXPECT_TRUE(any0 == nullptr); + EXPECT_FALSE(any0 != nullptr); +} + +TEST(Any, Float) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view0.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `float`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + AnyView view1_int = 1; + auto float_v1 = view1_int.cast(); + EXPECT_EQ(float_v1, 1); + + AnyView view2 = 2.2; + EXPECT_EQ(view2.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); + EXPECT_EQ(view2.CopyToTVMFFIAny().v_float64, 2.2); + + float v1 = 2; + view0 = v1; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); + EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2); +} + +TEST(Any, Device) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view0.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `Device`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + DLDevice device{kDLCUDA, 1}; + + AnyView view1_device = device; + auto dtype_v1 = view1_device.cast(); + EXPECT_EQ(dtype_v1.device_type, kDLCUDA); + EXPECT_EQ(dtype_v1.device_id, 1); + + Any any2 = DLDevice{kDLCPU, 0}; + TVMFFIAny ffi_v2 = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(any2)); + EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDevice); + EXPECT_EQ(ffi_v2.v_device.device_type, kDLCPU); + EXPECT_EQ(ffi_v2.v_device.device_id, 0); +} + +TEST(Any, DLTensor) { + AnyView view0; + + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view0.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `DLTensor*`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + DLTensor dltensor; + + AnyView view1_dl = &dltensor; + auto dl_v1 = view1_dl.cast(); + EXPECT_EQ(dl_v1, &dltensor); +} + +TEST(Any, Object) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + // int object is not nullable + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + TInt v1(11); + EXPECT_EQ(v1.use_count(), 1); + // view won't increase refcount + AnyView view1 = v1; + EXPECT_EQ(v1.use_count(), 1); + // any will trigger ref count increase + Any any1 = v1; + EXPECT_EQ(v1.use_count(), 2); + // copy to another view + AnyView view2 = any1; + EXPECT_EQ(v1.use_count(), 2); + + // convert to weak raw object ptr + const TIntObj* v1_ptr = view2.cast(); + EXPECT_EQ(v1.use_count(), 2); + EXPECT_EQ(v1_ptr->value, 11); + Any any2 = v1_ptr; + EXPECT_EQ(v1.use_count(), 3); + EXPECT_TRUE(any2.as().has_value()); + + // convert to raw opaque ptr + void* raw_v1_ptr = const_cast(v1_ptr); + any2 = raw_v1_ptr; + EXPECT_TRUE(any2.as().value() == v1_ptr); + + // convert to ObjectRef + { + auto v1_obj_ref = view2.cast(); + EXPECT_EQ(v1.use_count(), 3); + any2 = v1_obj_ref; + EXPECT_EQ(v1.use_count(), 4); + EXPECT_TRUE(any2.as().has_value()); + any2.reset(); + } + + // convert that triggers error + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view1.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + std::cout << what; + EXPECT_NE(what.find("Cannot convert from type `test.Int` to `test.Float`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + // Try to convert to number + auto number0 = any1.cast(); + EXPECT_EQ(v1.use_count(), 3); + EXPECT_TRUE(number0.as()); + EXPECT_EQ(number0.as()->value, 11); + EXPECT_TRUE(!any1.as().has_value()); + + auto int1 = view2.cast(); + EXPECT_EQ(v1.use_count(), 4); + any1.reset(); + EXPECT_EQ(v1.use_count(), 3); +} + +TEST(Any, ObjectRefWithFallbackTraits) { + // Test case for TPrimExpr fallback from Any + Any any1 = TPrimExpr("float32", 3.14); + auto v0 = any1.cast(); + EXPECT_EQ(v0->value, 3.14); + EXPECT_EQ(v0->dtype, "float32"); + + any1 = true; + auto v1 = any1.cast(); + EXPECT_EQ(v1->value, 1); + EXPECT_EQ(v1->dtype, "bool"); + + any1 = int64_t(42); + auto v2 = any1.cast(); + EXPECT_EQ(v2->value, 42); + EXPECT_EQ(v2->dtype, "int64"); + + any1 = 2.718; + auto v3 = any1.cast(); + EXPECT_EQ(v3->value, 2.718); + EXPECT_EQ(v3->dtype, "float32"); + + // Test case for TPrimExpr fallback from AnyView + TPrimExpr texpr1("float32", 3.14); + AnyView view1 = texpr1; + auto v4 = view1.cast(); + EXPECT_EQ(v4->value, 3.14); + EXPECT_EQ(v4->dtype, "float32"); + + view1 = true; + auto v5 = view1.cast(); + EXPECT_EQ(v5->value, 1); + EXPECT_EQ(v5->dtype, "bool"); + + view1 = int64_t(42); + auto v6 = view1.cast(); + EXPECT_EQ(v6->value, 42); + EXPECT_EQ(v6->dtype, "int64"); + + view1 = 2.718; + auto v7 = view1.cast(); + EXPECT_EQ(v7->value, 2.718); + EXPECT_EQ(v7->dtype, "float32"); + + // Test case for TPrimExpr fallback from Any with String + any1 = std::string("test_string"); + auto v8 = any1.cast(); + EXPECT_EQ(v8->dtype, "test_string"); + EXPECT_EQ(v8->value, 0); + + // Test case for TPrimExpr fallback from AnyView with String + view1 = "test_string"; + auto v9 = view1.cast(); + EXPECT_EQ(v9->dtype, "test_string"); + EXPECT_EQ(v9->value, 0); +} + +TEST(Any, ObjectMove) { + Any any1 = TPrimExpr("float32", 3.14); + auto v0 = std::move(any1).cast(); + EXPECT_EQ(v0->value, 3.14); + EXPECT_EQ(v0.use_count(), 1); +} + +} // namespace diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc new file mode 100644 index 000000000000..bb0b062c328a --- /dev/null +++ b/ffi/tests/cpp/test_array.cc @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Array, Basic) { + Array arr = {TInt(11), TInt(12)}; + TInt v1 = arr[0]; + EXPECT_EQ(v1->value, 11); + EXPECT_EQ(v1.use_count(), 2); + EXPECT_EQ(arr[1]->value, 12); +} + +TEST(Array, COWSet) { + Array arr = {TInt(11), TInt(12)}; + Array arr2 = arr; + EXPECT_EQ(arr.use_count(), 2); + arr.Set(1, TInt(13)); + EXPECT_EQ(arr.use_count(), 1); + EXPECT_EQ(arr[1]->value, 13); + EXPECT_EQ(arr2[1]->value, 12); +} + +TEST(Array, MutateInPlaceForUniqueReference) { + TInt x(1); + Array arr{x, x}; + EXPECT_TRUE(arr.unique()); + auto* before = arr.get(); + + arr.MutateByApply([](TInt) { return TInt(2); }); + auto* after = arr.get(); + EXPECT_EQ(before, after); +} + +TEST(Array, CopyWhenMutatingNonUniqueReference) { + TInt x(1); + Array arr{x, x}; + Array arr2 = arr; + + EXPECT_TRUE(!arr.unique()); + auto* before = arr.get(); + + arr.MutateByApply([](TInt) { return TInt(2); }); + auto* after = arr.get(); + EXPECT_NE(before, after); +} + +TEST(Array, Map) { + // Basic functionality + TInt x(1), y(1); + Array var_arr{x, y}; + Array expr_arr = + var_arr.Map([](TInt var) -> TNumber { return TFloat(static_cast(var->value + 1)); }); + + EXPECT_NE(var_arr.get(), expr_arr.get()); + EXPECT_TRUE(expr_arr[0]->IsInstance()); + EXPECT_TRUE(expr_arr[1]->IsInstance()); +} + +TEST(Array, Iterator) { + Array array{1, 2, 3}; + std::vector vector(array.begin(), array.end()); + EXPECT_EQ(vector[1], 2); +} + +TEST(Array, PushPop) { + Array a; + std::vector b; + for (int i = 0; i < 10; ++i) { + a.push_back(i); + b.push_back(i); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + int n = static_cast(a.size()); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + for (int i = 9; i >= 0; --i) { + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + a.pop_back(); + b.pop_back(); + int n = static_cast(a.size()); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + ASSERT_EQ(a.empty(), true); +} + +TEST(Array, ResizeReserveClear) { + for (size_t n = 0; n < 10; ++n) { + Array a; + Array b; + a.resize(n); + b.reserve(n); + ASSERT_EQ(a.size(), n); + ASSERT_GE(a.capacity(), n); + a.clear(); + b.clear(); + ASSERT_EQ(a.size(), 0); + ASSERT_EQ(b.size(), 0); + } +} + +TEST(Array, InsertErase) { + Array a; + std::vector b; + for (int n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (int pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, pos); + b.insert(b.begin() + pos, pos); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + 1); + ASSERT_EQ(b.size(), n + 1); + for (int k = 0; k <= n; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos); + b.erase(b.begin() + pos); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + +TEST(Array, InsertEraseRange) { + Array range_a{-1, -2, -3, -4}; + std::vector range_b{-1, -2, -3, -4}; + Array a; + std::vector b; + + static_assert(std::is_same_v); + for (size_t n = 1; n <= 10; ++n) { + a.insert(a.end(), static_cast(n)); + b.insert(b.end(), static_cast(n)); + for (size_t pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, range_a.begin(), range_a.end()); + b.insert(b.begin() + pos, range_b.begin(), range_b.end()); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + range_a.size()); + ASSERT_EQ(b.size(), n + range_b.size()); + size_t m = n + range_a.size(); + for (size_t k = 0; k < m; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos, a.begin() + pos + range_a.size()); + b.erase(b.begin() + pos, b.begin() + pos + range_b.size()); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + +TEST(Array, FuncArrayAnyArg) { + Function fadd_one = + Function::FromUnpacked([](Array a) -> Any { return a[0].cast() + 1; }); + EXPECT_EQ(fadd_one(Array{1}).cast(), 2); +} + +TEST(Array, MapUniquePropogation) { + // Basic functionality + Array var_arr{TInt(1), TInt(2)}; + var_arr.MutateByApply([](TInt x) -> TInt { + EXPECT_TRUE(x.unique()); + return x; + }); +} + +TEST(Array, AnyImplicitConversion) { + Array arr0_mixed = {11.1, 1}; + EXPECT_EQ(arr0_mixed[1].cast(), 1); + + AnyView view0 = arr0_mixed; + auto arr0_float = view0.cast>(); + // they are not the same because arr_mixed + // stores arr_mixed[1] as int but we need to convert to float + EXPECT_TRUE(!arr0_float.same_as(arr0_mixed)); + EXPECT_EQ(arr0_float[1], 1.0); + + Any any1 = arr0_float; + // if storage check passes, the same array get returned + auto arr1_float = any1.cast>(); + EXPECT_TRUE(arr1_float.same_as(arr0_float)); + // total count equals 3 include any1 + EXPECT_EQ(arr1_float.use_count(), 3); + + // convert to Array do not need any conversion + auto arr1_mixed = any1.cast>(); + EXPECT_TRUE(arr1_mixed.same_as(arr1_float)); + EXPECT_EQ(arr1_float.use_count(), 4); +} + +TEST(Array, AnyConvertCheck) { + Array arr = {11.1, 1}; + EXPECT_EQ(arr[1].cast(), 1); + + AnyView view0 = arr; + auto arr1 = view0.cast>(); + EXPECT_EQ(arr1[0], 11.1); + EXPECT_EQ(arr1[1], 1.0); + + Any any1 = arr; + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto arr2 = any1.cast>(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `Array[index 0: float]` to `Array`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + Array> arr_nested = {{}, {TInt(1), TFloat(2)}}; + any1 = arr_nested; + auto arr1_nested = any1.cast>>(); + EXPECT_EQ(arr1_nested.use_count(), 3); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto arr2 = any1.cast>>(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("`Array[index 1: Array[index 0: test.Int]]` to `Array>`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Array, Upcast) { + Array a0 = {1, 2, 3}; + Array a1 = a0; + EXPECT_EQ(a1[0].cast(), 1); + EXPECT_EQ(a1[1].cast(), 2); + EXPECT_EQ(a1[2].cast(), 3); + + Array> a2 = {a0}; + Array> a3 = a2; + Array> a4 = a2; + + static_assert(details::type_contains_v, Array>); + static_assert(details::type_contains_v>); +} + +} // namespace diff --git a/ffi/tests/cpp/test_c_ffi_abi.cc b/ffi/tests/cpp/test_c_ffi_abi.cc new file mode 100644 index 000000000000..1efceef2971a --- /dev/null +++ b/ffi/tests/cpp/test_c_ffi_abi.cc @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +namespace { + +TEST(ABIHeaderAlignment, Default) { + TVMFFIObject value; + value.type_index = 10; + EXPECT_EQ(reinterpret_cast(&value)->type_index, 10); + static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 16 bytes"); +} + +} // namespace diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc new file mode 100644 index 000000000000..e31df8761db0 --- /dev/null +++ b/ffi/tests/cpp/test_dtype.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace { + +using namespace tvm::ffi; + +TEST(DType, StringConversion) { + DLDataType dtype = DLDataType{kDLFloat, 32, 1}; + EXPECT_EQ(DLDataTypeToString(dtype), "float32"); + EXPECT_EQ(StringToDLDataType("float32"), dtype); + + dtype = DLDataType{kDLInt, 16, 2}; + EXPECT_EQ(DLDataTypeToString(dtype), "int16x2"); + EXPECT_EQ(StringToDLDataType("int16x2"), dtype); + + dtype = DLDataType{kDLOpaqueHandle, 0, 0}; + EXPECT_EQ(DLDataTypeToString(dtype), ""); + EXPECT_EQ(StringToDLDataType("void"), dtype); + + // test bfloat with lanes + dtype = DLDataType{kDLBfloat, 16, 2}; + EXPECT_EQ(DLDataTypeToString(dtype), "bfloat16x2"); + EXPECT_EQ(StringToDLDataType("bfloat16x2"), dtype); + + // test float8 + dtype = DLDataType{kDLFloat8_e4m3fn, 8, 2}; + EXPECT_EQ(DLDataTypeToString(dtype), "float8_e4m3fnx2"); + EXPECT_EQ(StringToDLDataType("float8_e4m3fnx2"), dtype); +} + +TEST(DType, StringConversionAllDLPackTypes) { + std::vector> test_cases = { + {DLDataType{kDLFloat, 32, 1}, "float32"}, + {DLDataType{kDLInt, 16, 1}, "int16"}, + {DLDataType{kDLUInt, 16, 1}, "uint16"}, + {DLDataType{kDLBfloat, 16, 1}, "bfloat16"}, + {DLDataType{kDLFloat8_e3m4, 8, 1}, "float8_e3m4"}, + {DLDataType{kDLFloat8_e4m3, 8, 1}, "float8_e4m3"}, + {DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}, "float8_e4m3b11fnuz"}, + {DLDataType{kDLFloat8_e4m3fn, 8, 1}, "float8_e4m3fn"}, + {DLDataType{kDLFloat8_e4m3fnuz, 8, 1}, "float8_e4m3fnuz"}, + {DLDataType{kDLFloat8_e5m2, 8, 1}, "float8_e5m2"}, + {DLDataType{kDLFloat8_e5m2fnuz, 8, 1}, "float8_e5m2fnuz"}, + {DLDataType{kDLFloat8_e8m0fnu, 8, 1}, "float8_e8m0fnu"}, + {DLDataType{kDLFloat6_e2m3fn, 6, 1}, "float6_e2m3fn"}, + {DLDataType{kDLFloat6_e3m2fn, 6, 1}, "float6_e3m2fn"}, + {DLDataType{kDLFloat4_e2m1fn, 4, 1}, "float4_e2m1fn"}, + }; + + for (const auto& [dtype, str] : test_cases) { + EXPECT_EQ(DLDataTypeToString(dtype), str); + EXPECT_EQ(StringToDLDataType(str), dtype); + } +} + +TEST(DataType, AnyConversion) { + AnyView view0; + EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); + + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto v0 = view0.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `None` to `DataType`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + DLDataType dtype{kDLFloat, 32, 1}; + + AnyView view1_dtype = dtype; + auto dtype_v1 = view1_dtype.cast(); + EXPECT_EQ(dtype_v1.code, kDLFloat); + EXPECT_EQ(dtype_v1.bits, 32); + EXPECT_EQ(dtype_v1.lanes, 1); + + Any any2 = DLDataType{kDLInt, 16, 2}; + TVMFFIAny ffi_v2 = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(any2)); + EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDataType); + EXPECT_EQ(ffi_v2.v_dtype.code, kDLInt); + EXPECT_EQ(ffi_v2.v_dtype.bits, 16); + EXPECT_EQ(ffi_v2.v_dtype.lanes, 2); +} + +// String can be automatically converted to DLDataType +TEST(DataType, AnyConversionWithString) { + AnyView view0 = "float32"; + + Optional opt_v0 = view0.as(); + DLDataType dtype_v0 = opt_v0.value(); + EXPECT_EQ(dtype_v0.code, kDLFloat); + EXPECT_EQ(dtype_v0.bits, 32); + EXPECT_EQ(dtype_v0.lanes, 1); + + Any any = String("bfloat16x2"); + Optional opt_v1 = any.as(); + EXPECT_EQ(opt_v1.value().code, kDLBfloat); + EXPECT_EQ(opt_v1.value().bits, 16); + EXPECT_EQ(opt_v1.value().lanes, 2); +} +} // namespace diff --git a/ffi/tests/cpp/test_error.cc b/ffi/tests/cpp/test_error.cc new file mode 100644 index 000000000000..9938603a47ba --- /dev/null +++ b/ffi/tests/cpp/test_error.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace { + +using namespace tvm::ffi; + +void ThrowRuntimeError() { TVM_FFI_THROW(RuntimeError) << "test0"; } + +TEST(Error, Traceback) { + EXPECT_THROW( + { + try { + ThrowRuntimeError(); + } catch (const Error& error) { + EXPECT_EQ(error.message(), "test0"); + EXPECT_EQ(error.kind(), "RuntimeError"); + std::string what = error.what(); + EXPECT_NE(what.find("line"), std::string::npos); + EXPECT_NE(what.find("ThrowRuntimeError"), std::string::npos); + EXPECT_NE(what.find("RuntimeError: test0"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(CheckError, Traceback) { + EXPECT_THROW( + { + try { + TVM_FFI_ICHECK_GT(2, 3); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "InternalError"); + std::string what = error.what(); + EXPECT_NE(what.find("line"), std::string::npos); + EXPECT_NE(what.find("2 > 3"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Error, AnyConvert) { + Any any = Error("TypeError", "here", "test0"); + Optional opt_err = any.as(); + EXPECT_EQ(opt_err.value().kind(), "TypeError"); + EXPECT_EQ(opt_err.value().message(), "here"); +} +} // namespace diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc new file mode 100644 index 000000000000..fbdc580f3b2b --- /dev/null +++ b/ffi/tests/cpp/test_function.cc @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Func, FromPacked) { + Function fadd1 = Function::FromPacked([](const AnyView* args, int32_t num_args, Any* rv) { + EXPECT_EQ(num_args, 1); + int32_t a = args[0].cast(); + *rv = a + 1; + }); + int b = fadd1(1).cast(); + EXPECT_EQ(b, 2); + + Function fadd2 = Function::FromPacked([](const AnyView* args, int32_t num_args, Any* rv) { + EXPECT_EQ(num_args, 1); + auto a = args[0].cast(); + EXPECT_EQ(a.use_count(), 2); + *rv = a->value + 1; + }); + EXPECT_EQ(fadd2(TInt(12)).cast(), 13); +} + +TEST(Func, PackedArgs) { + Function fadd1 = Function::FromPacked([](PackedArgs args, Any* rv) { + EXPECT_EQ(args.size(), 1); + int32_t a = args[0].cast(); + *rv = a + 1; + }); + int b = fadd1(1).cast(); + EXPECT_EQ(b, 2); + + Function fadd2 = Function::FromPacked([](PackedArgs args, Any* rv) { + EXPECT_EQ(args.size(), 1); + TInt a = args[0].cast(); + EXPECT_EQ(a.use_count(), 2); + *rv = a->value + 1; + }); + EXPECT_EQ(fadd2(TInt(12)).cast(), 13); + + TInt v(12); + AnyView data[3]; + PackedArgs::Fill(data, 3, 1, v); + EXPECT_EQ(data[0].cast(), 3); + EXPECT_EQ(data[1].cast(), 1); + EXPECT_EQ(data[2].cast()->value, 12); +} + +TEST(Func, FromUnpacked) { + // try decution + Function fadd1 = Function::FromUnpacked([](const int32_t& a) -> int { return a + 1; }); + int b = fadd1(1).cast(); + EXPECT_EQ(b, 2); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(1.1); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched type on argument #0 when calling: `(0: int) -> int`. " + "Expected `int` but got `float`"); + throw; + } + }, + ::tvm::ffi::Error); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched number of arguments when calling: `(0: int) -> int`. " + "Expected 1 but got 0 arguments"); + throw; + } + }, + ::tvm::ffi::Error); + + // try decution + Function fpass_and_return = Function::FromUnpacked( + [](TInt x, int value, AnyView z) -> Function { + EXPECT_EQ(x.use_count(), 2); + EXPECT_EQ(x->value, value); + if (auto opt = z.as()) { + EXPECT_EQ(value, *opt); + } + return Function::FromUnpacked([value](int x) -> int { return x + value; }); + }, + "fpass_and_return"); + TInt a(11); + auto fret = fpass_and_return(std::move(a), 11, 11).cast(); + EXPECT_EQ(fret(12).cast(), 23); + + EXPECT_THROW( + { + try { + fpass_and_return(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched number of arguments when calling: " + "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> object.Function`. " + "Expected 3 but got 0 arguments"); + throw; + } + }, + ::tvm::ffi::Error); + + Function fconcact = + Function::FromUnpacked([](const String& a, const String& b) -> String { return a + b; }); + EXPECT_EQ(fconcact("abc", "def").cast(), "abcdef"); +} + +TEST(Func, PassReturnAny) { + Function fadd_one = Function::FromUnpacked([](Any a) -> Any { return a.cast() + 1; }); + EXPECT_EQ(fadd_one(1).cast(), 2); +} + +TEST(Func, Global) { + Function::SetGlobal("testing.add1", + Function::FromUnpacked([](const int32_t& a) -> int { return a + 1; })); + auto fadd1 = Function::GetGlobalRequired("testing.add1"); + int b = fadd1(1).cast(); + EXPECT_EQ(b, 2); + auto fnot_exist = Function::GetGlobal("testing.not_existing_func"); + EXPECT_TRUE(!fnot_exist); + + auto fname_functor = + Function::GetGlobal("ffi.FunctionListGlobalNamesFunctor").value()().cast(); + Array names; + int len = fname_functor(-1).cast(); + for (int i = 0; i < len; ++i) { + names.push_back(fname_functor(i).cast()); + } + EXPECT_TRUE(std::find(names.begin(), names.end(), "testing.add1") != names.end()); +} + +TEST(Func, TypedFunction) { + TypedFunction fadd1 = [](int a) -> int { return a + 1; }; + EXPECT_EQ(fadd1(1), 2); + + TypedFunction fadd2([](int a) -> int { return a + 2; }); + EXPECT_EQ(fadd2(1), 3); + EXPECT_EQ(fadd2.packed()(1).cast(), 3); + + TypedFunction fcheck_int; + EXPECT_TRUE(fcheck_int == nullptr); + fcheck_int = [](int a) -> void { EXPECT_EQ(a, 1); }; + fcheck_int(1); +} + +TEST(Func, TypedFunctionAsAny) { + TypedFunction fadd1 = [](int a) -> int { return a + 1; }; + Any fany(std::move(fadd1)); + EXPECT_TRUE(fadd1 == nullptr); + auto fadd1_dup = fany.cast>(); + EXPECT_EQ(fadd1_dup(1), 2); +} + +TEST(Func, TypedFunctionAsAnyView) { + TypedFunction fadd2 = [](int a) -> int { return a + 2; }; + AnyView fview(fadd2); + auto fadd2_dup = fview.cast>(); + EXPECT_EQ(fadd2_dup(1), 3); +} + +TEST(Func, ObjectRefWithFallbackTraits) { + // test cases to test automatic type conversion via ObjectRefWithFallbackTraits + // through TPrimExpr + Function freturn_primexpr = Function::FromUnpacked([](TPrimExpr a) -> TPrimExpr { return a; }); + + auto result_int = freturn_primexpr(1).cast(); + EXPECT_EQ(result_int->dtype, "int64"); + EXPECT_EQ(result_int->value, 1); + + // Test case for float + auto result_float = freturn_primexpr(2.5).cast(); + EXPECT_EQ(result_float->dtype, "float32"); + EXPECT_EQ(result_float->value, 2.5); + + // Test case for bool + auto result_bool = freturn_primexpr(true).cast(); + EXPECT_EQ(result_bool->dtype, "bool"); + EXPECT_EQ(result_bool->value, 1); + + // Test case for string + auto result_string = freturn_primexpr("test_string").cast(); + EXPECT_EQ(result_string->dtype, "test_string"); + EXPECT_EQ(result_string->value, 0); + + EXPECT_THROW( + { + try { + freturn_primexpr(TInt(1)); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ( + error.message(), + "Mismatched type on argument #0 when calling: `(0: test.PrimExpr) -> test.PrimExpr`. " + "Expected `test.PrimExpr` but got `test.Int`"); + throw; + } + }, + ::tvm::ffi::Error); +} + +TVM_FFI_REGISTER_GLOBAL("testing.Int_GetValue").set_body_method(&TIntObj::GetValue); + +TEST(Func, Register) { + Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue"); + TInt a(12); + EXPECT_EQ(fget_value(a).cast(), 12); +} +} // namespace diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc new file mode 100644 index 000000000000..1c43230bbc1f --- /dev/null +++ b/ffi/tests/cpp/test_map.cc @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Map, Basic) { + Map map0; + TInt k0(0); + map0.Set(k0, 1); + + EXPECT_EQ(map0.size(), 1); + + map0.Set(k0, 2); + EXPECT_EQ(map0.size(), 1); + + auto it = map0.find(k0); + EXPECT_TRUE(it != map0.end()); + EXPECT_EQ((*it).second, 2); +} + +TEST(Map, PODKey) { + Map map0; + + // int as key + map0.Set(1, 2); + // float key is different + map0.Set(1.1, 3); + EXPECT_EQ(map0.size(), 2); + + auto it = map0.find(1.1); + EXPECT_TRUE(it != map0.end()); + EXPECT_EQ((*it).second.cast(), 3); +} + +TEST(Map, Object) { + TInt x(1); + TInt z(100); + TInt zz(1000); + Map dict{{x, z}, {z, zz}}; + EXPECT_EQ(dict.size(), 2); + EXPECT_TRUE(dict[x].same_as(z)); + EXPECT_TRUE(dict.count(z)); + EXPECT_TRUE(!dict.count(zz)); +} + +TEST(Map, Str) { + TInt x(1); + TInt z(100); + Map dict{{"x", z}, {"z", z}}; + EXPECT_EQ(dict.size(), 2); + EXPECT_TRUE(dict["x"].same_as(z)); +} + +TEST(Map, Mutate) { + TInt x(1); + TInt z(100); + TInt zz(1000); + Map dict{{x, z}, {z, zz}}; + + EXPECT_TRUE(dict[x].same_as(z)); + dict.Set(x, zz); + auto dict2 = dict; + EXPECT_EQ(dict2.count(z), 1); + dict.Set(zz, x); + EXPECT_EQ(dict2.count(zz), 0); + EXPECT_EQ(dict.count(zz), 1); + + auto it = dict.find(zz); + EXPECT_TRUE(it != dict.end() && (*it).second.same_as(x)); + + it = dict2.find(zz); + EXPECT_TRUE(it == dict2.end()); +} + +TEST(Map, Clear) { + TInt x(1); + TInt z(100); + Map dict{{x, z}, {z, z}}; + EXPECT_EQ(dict.size(), 2); + dict.clear(); + EXPECT_EQ(dict.size(), 0); +} + +TEST(Map, Insert) { + auto check = [](const Map& result, + std::unordered_map expected) { + EXPECT_EQ(result.size(), expected.size()); + for (const auto& kv : result) { + EXPECT_TRUE(expected.count(kv.first)); + EXPECT_EQ(expected[kv.first], kv.second); + expected.erase(kv.first); + } + }; + Map result; + std::unordered_map expected; + char key = 'a'; + int64_t val = 1; + for (int i = 0; i < 26; ++i, ++key, ++val) { + std::string s(1, key); + result.Set(s, val); + expected[s] = val; + check(result, expected); + } +} + +TEST(Map, Erase) { + auto check = [](const Map& result, + std::unordered_map expected) { + EXPECT_EQ(result.size(), expected.size()); + for (const auto& kv : result) { + EXPECT_TRUE(expected.count(kv.first)); + EXPECT_EQ(expected[kv.first], kv.second); + expected.erase(kv.first); + } + }; + Map map{{"a", 1}, {"b", 2}, {"c", 3}, {"d", 4}, {"e", 5}}; + std::unordered_map stl; + std::transform(map.begin(), map.end(), std::inserter(stl, stl.begin()), + [](auto&& p) { return std::make_pair(p.first, p.second); }); + for (char c = 'a'; c <= 'e'; ++c) { + Map result = map; + std::unordered_map expected(stl); + std::string key(1, c); + result.erase(key); + expected.erase(key); + check(result, expected); + } +} + +TEST(Map, AnyImplicitConversion) { + Map map0; + map0.Set(1, 2); + map0.Set(2, 3.1); + EXPECT_EQ(map0.size(), 2); + + // check will trigger copy + AnyView view0 = map0; + auto map1 = view0.cast>(); + EXPECT_TRUE(!map1.same_as(map0)); + EXPECT_EQ(map1[1], 2); + EXPECT_EQ(map1[2], 3.1); + EXPECT_EQ(map1.use_count(), 1); + + auto map2 = view0.cast>(); + EXPECT_TRUE(map2.same_as(map0)); + EXPECT_EQ(map2.use_count(), 2); + + auto map3 = view0.cast>(); + EXPECT_TRUE(!map3.same_as(map0)); + EXPECT_EQ(map3.use_count(), 1); + + Map map4{{"yes", 1.1}, {"no", 2.2}}; + Any any1 = map4; + + auto map5 = any1.cast>(); + EXPECT_TRUE(map5.same_as(map4)); + EXPECT_EQ(map5.use_count(), 3); + + auto map6 = any1.cast>(); + EXPECT_TRUE(map6.same_as(map4)); + EXPECT_EQ(map6.use_count(), 4); + + EXPECT_EQ(map6["yes"].cast(), 1.1); + EXPECT_EQ(map6["no"].cast(), 2.2); + + auto map7 = any1.cast>(); + EXPECT_TRUE(map7.same_as(map4)); + EXPECT_EQ(map7.use_count(), 5); + + auto map8 = any1.cast>(); + EXPECT_TRUE(!map8.same_as(map4)); + EXPECT_EQ(map8.use_count(), 1); + EXPECT_EQ(map8["yes"]->value, 1.1); + EXPECT_EQ(map8["no"]->value, 2.2); +} + +TEST(Map, AnyConvertCheck) { + Map map = {{11, 1.1}}; + EXPECT_EQ(map[11].cast(), 1.1); + + AnyView view0 = map; + auto arr1 = view0.cast>(); + EXPECT_EQ(arr1[11], 1.1); + + Any any1 = map; + using WrongMap = Map; + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto arr2 = any1.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE( + what.find( + "Cannot convert from type `Map[K, some value is float]` to `Map`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + using WrongMap2 = Map; + EXPECT_THROW( + { + try { + [[maybe_unused]] auto arr2 = any1.cast(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `Map[some key is int, V]` to " + "`Map`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Map, PackedFuncGetItem) { + Function f = Function::FromUnpacked([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, + "map_get_item"); + Map map{{"x", 1}, {"y", 2}}; + Any k("x"); + Any v = f(map, k); + EXPECT_EQ(v.cast(), 1); +} + +TEST(Map, Upcast) { + Map m0 = {{1, 2}, {3, 4}}; + Map m1 = m0; + EXPECT_EQ(m1[1].cast(), 2); + EXPECT_EQ(m1[3].cast(), 4); + static_assert(details::type_contains_v, Map>); + + Map> m2 = {{"x", {1}}, {"y", {2}}}; + Map> m3 = m2; +} + +template +void PrintMap(const Map& m0) { + std::cout << "{"; + for (auto it = m0.begin(); it != m0.end(); ++it) { + if (it != m0.begin()) { + std::cout << ", "; + } + std::cout << (*it).first << ": " << (*it).second; + } + std::cout << "}" << std::endl; +} + +TEST(Map, MapInsertOrder) { + // test that map preserves the insertion order + auto get_reverse_order = [](size_t size) { + std::vector reverse_order; + for (int i = static_cast(size); i != 0; --i) { + reverse_order.push_back(i - 1); + } + return reverse_order; + }; + + auto check_map = [&](Map m0, size_t size, const std::vector& order) { + auto lhs = m0.begin(); + auto rhs = order.begin(); + while (lhs != m0.end()) { + TVM_FFI_ICHECK_EQ((*lhs).first, "hello" + std::to_string(*rhs)); + TVM_FFI_ICHECK_EQ((*lhs).second, *rhs); + ++lhs; + ++rhs; + } + lhs = m0.end(); + rhs = order.begin() + size; + do { + --lhs; + --rhs; + TVM_FFI_ICHECK_EQ((*lhs).first, "hello" + std::to_string(*rhs)); + TVM_FFI_ICHECK_EQ((*lhs).second, *rhs); + } while (lhs != m0.begin()); + }; + + auto check_order = [&](std::vector order) { + Map m0; + for (size_t i = 0; i < order.size(); ++i) { + m0.Set("hello" + std::to_string(order[i]), order[i]); + check_map(m0, i + 1, order); + } + check_map(m0, order.size(), order); + // erase a few items + m0.erase("hello" + std::to_string(order[0])); + auto item0 = order[0]; + order.erase(order.begin()); + check_map(m0, order.size(), order); + // erase the middle part + if (order.size() > 1) { + m0.erase("hello" + std::to_string(order[1])); + order.erase(order.begin() + 1); + check_map(m0, order.size(), order); + } + // erase the end + m0.erase("hello" + std::to_string(order.back())); + auto item2 = order.back(); + order.erase(order.end() - 1); + check_map(m0, order.size(), order); + EXPECT_NE(m0.size(), 0); + // put back some items + order.push_back(item2); + m0.Set("hello" + std::to_string(item2), item2); + check_map(m0, order.size(), order); + order.push_back(item0); + m0.Set("hello" + std::to_string(item0), item0); + check_map(m0, order.size(), order); + }; + // test with 17 items: DenseMapObj + check_order(get_reverse_order(17)); + // test with 4 items: SmallMapObj + check_order(get_reverse_order(4)); +} + +TEST(Map, EmptyIter) { + Map m0; + EXPECT_EQ(m0.begin(), m0.end()); + // create a big map and then erase to keep a dense map empty + for (int i = 0; i < 10; ++i) { + m0.Set("hello" + std::to_string(i), i); + } + for (int i = 0; i < 10; ++i) { + m0.erase("hello" + std::to_string(i)); + } + EXPECT_EQ(m0.size(), 0); + // now m0 is dense map with all empty slots + EXPECT_EQ(m0.begin(), m0.end()); +} +} // namespace diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc new file mode 100644 index 000000000000..3d7b00cd33c3 --- /dev/null +++ b/ffi/tests/cpp/test_ndarray.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +namespace { + +using namespace tvm::ffi; + +struct CPUNDAlloc { + void AllocData(DLTensor* tensor) { tensor->data = malloc(GetDataSize(*tensor)); } + void FreeData(DLTensor* tensor) { free(tensor->data); } +}; + +inline NDArray Empty(Shape shape, DLDataType dtype, DLDevice device) { + return NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); +} + +TEST(NDArray, Basic) { + NDArray nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + Shape shape = nd.shape(); + EXPECT_EQ(shape.size(), 3); + EXPECT_EQ(shape[0], 1); + EXPECT_EQ(shape[1], 2); + EXPECT_EQ(shape[2], 3); + EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1})); + for (int64_t i = 0; i < shape.Product(); ++i) { + reinterpret_cast(nd->data)[i] = static_cast(i); + } + + Any any0 = nd; + NDArray nd2 = any0.as().value(); + EXPECT_EQ(nd2.shape(), shape); + EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1})); + for (int64_t i = 0; i < shape.Product(); ++i) { + EXPECT_EQ(reinterpret_cast(nd2->data)[i], i); + } + + EXPECT_EQ(nd.IsContiguous(), true); + EXPECT_EQ(nd2.use_count(), 3); +} + +TEST(NDArray, DLPack) { + NDArray nd = Empty({1, 2, 3}, DLDataType({kDLInt, 16, 1}), DLDevice({kDLCPU, 0})); + DLManagedTensor* dlpack = nd.ToDLPack(); + EXPECT_EQ(dlpack->dl_tensor.ndim, 3); + EXPECT_EQ(dlpack->dl_tensor.shape[0], 1); + EXPECT_EQ(dlpack->dl_tensor.shape[1], 2); + EXPECT_EQ(dlpack->dl_tensor.shape[2], 3); + EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLInt); + EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 16); + EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); + EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); + EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); + EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); + EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); + EXPECT_EQ(nd.use_count(), 2); + { + NDArray nd2 = NDArray::FromDLPack(dlpack); + EXPECT_EQ(nd2.use_count(), 1); + EXPECT_EQ(nd2->data, nd->data); + EXPECT_EQ(nd.use_count(), 2); + EXPECT_EQ(nd2.use_count(), 1); + } + EXPECT_EQ(nd.use_count(), 1); +} + +TEST(NDArray, DLPackVersioned) { + DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1}); + EXPECT_EQ(GetDataSize(2, dtype), 2 * 4 / 8); + NDArray nd = Empty({2}, dtype, DLDevice({kDLCPU, 0})); + DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned(); + EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION); + EXPECT_EQ(dlpack->version.minor, DLPACK_MINOR_VERSION); + EXPECT_EQ(dlpack->dl_tensor.ndim, 1); + EXPECT_EQ(dlpack->dl_tensor.shape[0], 2); + EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLFloat4_e2m1fn); + EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 4); + EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); + EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); + EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); + EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); + EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); + + EXPECT_EQ(nd.use_count(), 2); + { + NDArray nd2 = NDArray::FromDLPackVersioned(dlpack); + EXPECT_EQ(nd2.use_count(), 1); + EXPECT_EQ(nd2->data, nd->data); + EXPECT_EQ(nd.use_count(), 2); + EXPECT_EQ(nd2.use_count(), 1); + } + EXPECT_EQ(nd.use_count(), 1); +} +} // namespace diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc new file mode 100644 index 000000000000..c370ff51a42f --- /dev/null +++ b/ffi/tests/cpp/test_object.cc @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Object, RefCounter) { + ObjectPtr a = make_object(11); + ObjectPtr b = a; + + EXPECT_EQ(a->value, 11); + + EXPECT_EQ(a.use_count(), 2); + ObjectPtr aa = make_object(*a); + EXPECT_EQ(aa.use_count(), 1); + EXPECT_EQ(aa->value, 11); + + b.reset(); + EXPECT_EQ(a.use_count(), 1); + EXPECT_TRUE(b == nullptr); + EXPECT_EQ(b.use_count(), 0); + + ObjectPtr c = std::move(a); + EXPECT_EQ(c.use_count(), 1); + EXPECT_TRUE(a == nullptr); + + EXPECT_EQ(c->value, 11); +} + +TEST(Object, TypeInfo) { + const TypeInfo* info = TVMFFIGetTypeInfo(TIntObj::RuntimeTypeIndex()); + EXPECT_TRUE(info != nullptr); + EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex()); + EXPECT_EQ(info->type_depth, 2); + EXPECT_EQ(info->type_acenstors[0], Object::_type_index); + EXPECT_EQ(info->type_acenstors[1], TNumberObj::_type_index); + EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin); +} + +TEST(Object, InstanceCheck) { + ObjectPtr a = make_object(11); + ObjectPtr b = make_object(11); + + EXPECT_TRUE(a->IsInstance()); + EXPECT_TRUE(a->IsInstance()); + EXPECT_TRUE(a->IsInstance()); + EXPECT_TRUE(!a->IsInstance()); + + EXPECT_TRUE(a->IsInstance()); + EXPECT_TRUE(b->IsInstance()); + EXPECT_TRUE(!b->IsInstance()); + EXPECT_TRUE(b->IsInstance()); +} + +TEST(ObjectRef, as) { + ObjectRef a = TInt(10); + ObjectRef b = TFloat(20); + // nullable object + ObjectRef c(nullptr); + + EXPECT_TRUE(a.as() != nullptr); + EXPECT_TRUE(a.as() == nullptr); + EXPECT_TRUE(a.as() != nullptr); + + EXPECT_TRUE(b.as() == nullptr); + EXPECT_TRUE(b.as() != nullptr); + EXPECT_TRUE(b.as() != nullptr); + + EXPECT_TRUE(c.as() == nullptr); + EXPECT_TRUE(c.as() == nullptr); + EXPECT_TRUE(c.as() == nullptr); + + EXPECT_EQ(a.as()->value, 10); + EXPECT_EQ(b.as()->value, 20); +} + +TEST(Object, CAPIAccessor) { + ObjectRef a = TInt(10); + TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(a); + int32_t type_index = TVMFFIObjectGetTypeIndex(obj); + EXPECT_EQ(type_index, TIntObj::RuntimeTypeIndex()); +} +} // namespace diff --git a/ffi/tests/cpp/test_optional.cc b/ffi/tests/cpp/test_optional.cc new file mode 100644 index 000000000000..256a7da8b42c --- /dev/null +++ b/ffi/tests/cpp/test_optional.cc @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Optional, TInt) { + Optional x; + Optional y = TInt(11); + static_assert(sizeof(Optional) == sizeof(ObjectRef)); + + EXPECT_TRUE(!x.has_value()); + EXPECT_EQ(x.value_or(TInt(12))->value, 12); + + EXPECT_TRUE(y.has_value()); + EXPECT_EQ(y.value_or(TInt(12))->value, 11); + + Any z_any = std::move(y); + EXPECT_TRUE(z_any != nullptr); + EXPECT_EQ((z_any.cast())->value, 11); + EXPECT_TRUE(!y.has_value()); + + // move from any to optional + auto y2 = std::move(z_any).cast>(); + EXPECT_EQ(y2.use_count(), 1); + EXPECT_TRUE(y2.has_value()); + EXPECT_EQ(y2.value_or(TInt(12))->value, 11); +} + +TEST(Optional, double) { + Optional x; + Optional y = 11.0; + static_assert(sizeof(Optional) > sizeof(ObjectRef)); + + EXPECT_TRUE(!x.has_value()); + EXPECT_EQ(x.value_or(12), 12); + EXPECT_TRUE(x != 12); + + EXPECT_TRUE(y.has_value()); + EXPECT_EQ(y.value_or(12), 11); + EXPECT_TRUE(y == 11); + EXPECT_TRUE(y != 12); +} + +TEST(Optional, AnyConvert_int) { + Optional opt_v0 = 1; + EXPECT_EQ(opt_v0.value(), 1); + EXPECT_TRUE(opt_v0.has_value()); + + AnyView view0 = opt_v0; + EXPECT_EQ(view0.cast(), 1); + + Any any1; + auto opt_v1 = std::move(any1).cast>(); + EXPECT_TRUE(!opt_v1.has_value()); + Optional opt_v2 = 11; + Any any2 = std::move(opt_v2); + EXPECT_EQ(any2.cast(), 11); +} + +TEST(Optional, AnyConvert_Array) { + AnyView view0; + Array> arr_nested = {{}, {TInt(1), TFloat(2)}}; + view0 = arr_nested; + + auto opt_arr = view0.cast>>>(); + EXPECT_EQ(arr_nested.use_count(), 2); + + auto arr1 = view0.cast>>>(); + EXPECT_EQ(arr_nested.use_count(), 3); + EXPECT_EQ(arr1.value()[1][1].as()->value, 2); + + Any any1; + auto arr2 = any1.cast>>>(); + EXPECT_TRUE(!arr2.has_value()); + + EXPECT_THROW( + { + try { + [[maybe_unused]] auto arr2 = view0.cast>>>(); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + std::string what = error.what(); + std::cout << what << std::endl; + EXPECT_NE(what.find("to `Optional>>`"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Optional, OptionalOfOptional) { + // testcase of optional + Optional> opt_opt_int; + EXPECT_TRUE(!opt_opt_int.has_value()); + + Optional> opt_opt_int2 = Optional(std::nullopt); + EXPECT_TRUE(opt_opt_int2.has_value()); + EXPECT_TRUE(!opt_opt_int2.value().has_value()); + + // Optional> + Optional> opt_opt_tint; + EXPECT_TRUE(!opt_opt_tint.has_value()); + + Optional> opt_opt_tint2 = Optional(std::nullopt); + EXPECT_TRUE(opt_opt_tint2.has_value()); + EXPECT_TRUE(!opt_opt_tint2.value().has_value()); + opt_opt_tint2 = std::nullopt; + EXPECT_TRUE(!opt_opt_tint2.has_value()); + + Optional> opt_opt_tint3 = Optional(TInt(42)); + EXPECT_TRUE(opt_opt_tint3.has_value()); + EXPECT_TRUE(opt_opt_tint3.value().has_value()); + EXPECT_EQ(opt_opt_tint3.value().value()->value, 42); +} + +TEST(Optional, ValueMove) { + Optional y = TInt(11); + TInt x = std::move(y).value(); + EXPECT_TRUE(!y.has_value()); + EXPECT_EQ(x->value, 11); + + Optional opt_tint = TInt(21); + EXPECT_TRUE(opt_tint.has_value()); + EXPECT_EQ((*opt_tint)->value, 21); + + TInt moved_tint = *std::move(opt_tint); + EXPECT_EQ(moved_tint->value, 21); + EXPECT_TRUE(!opt_tint.has_value()); +} + +TEST(Optional, OptionalInArray) { + // This pattern plus iteration may cause memory leak + // this is because arr[0] returns a temporary object + // and further call arr[0].value() may return a reference to + // the temporary object + Array>> arr = {Array({TInt(0), TInt(1)})}; + int counter = 0; + + for (const auto& x : arr[0].value()) { + EXPECT_EQ(x->value, counter++); + } + + Any any = arr; + auto opt_arr = any.cast>>>(); + EXPECT_EQ(opt_arr[0].value()[0]->value, 0); +} +} // namespace diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc new file mode 100644 index 000000000000..fec167d257e3 --- /dev/null +++ b/ffi/tests/cpp/test_reflection.cc @@ -0,0 +1,52 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +struct A : public Object { + int64_t x; + int64_t y; +}; + +TEST(Reflection, GetFieldByteOffset) { + EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject)); + EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject)); + EXPECT_EQ(details::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject)); +} + +TEST(Reflection, FieldGetter) { + ObjectRef a = TInt(10); + details::ReflectionFieldGetter getter(details::GetReflectionFieldInfo("test.Int", "value")); + EXPECT_EQ(getter(a).cast(), 10); + + ObjectRef b = TFloat(10.0); + details::ReflectionFieldGetter getter_float( + details::GetReflectionFieldInfo("test.Float", "value")); + EXPECT_EQ(getter_float(b).cast(), 10.0); +} +} // namespace diff --git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc new file mode 100644 index 000000000000..ac81208d48ba --- /dev/null +++ b/ffi/tests/cpp/test_rvalue_ref.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(RValueRef, Basic) { + auto append = + Function::FromUnpacked([](RValueRef> ref, int val, bool is_unique) -> Array { + Array arr = *std::move(ref); + EXPECT_EQ(arr.unique(), is_unique); + arr.push_back(val); + return arr; + }); + auto a = append(RValueRef(Array({1, 2})), 3, true).cast>(); + EXPECT_EQ(a.size(), 3); + a = append(RValueRef(std::move(a)), 4, true).cast>(); + EXPECT_EQ(a.size(), 4); + // pass in lvalue instead, the append still will succeed but array will not be unique + a = append(a, 5, false).cast>(); + EXPECT_EQ(a.size(), 5); +} + +TEST(RValueRef, ParamChecking) { + // try decution + Function fadd1 = Function::FromUnpacked([](TInt a) -> int64_t { return a->value + 1; }); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(RValueRef(TInt(1))); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched type on argument #0 when calling: `(0: test.Int) -> int`. " + "Expected `test.Int` but got `ObjectRValueRef`"); + throw; + } + }, + ::tvm::ffi::Error); + + Function fadd2 = Function::FromUnpacked([](RValueRef> a) -> int { + Array arr = *std::move(a); + return arr[0] + 1; + }); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd2(RValueRef(Array({1, 2.2}))); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ( + error.message(), + "Mismatched type on argument #0 when calling: `(0: RValueRef>) -> int`. " + "Expected `RValueRef>` but got `RValueRef`"); + throw; + } + }, + ::tvm::ffi::Error); + // triggered a rvalue based conversion + Function func3 = Function::FromUnpacked([](RValueRef a) -> String { + TPrimExpr expr = *std::move(a); + return expr->dtype; + }); + EXPECT_EQ(func3(RValueRef(String("int32"))).cast(), "int32"); + // triggered a lvalue based conversion + EXPECT_EQ(func3(String("int32")).cast(), "int32"); +} +} // namespace diff --git a/ffi/tests/cpp/test_shape.cc b/ffi/tests/cpp/test_shape.cc new file mode 100644 index 000000000000..0ccba7820ad7 --- /dev/null +++ b/ffi/tests/cpp/test_shape.cc @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace { + +using namespace tvm::ffi; + +TEST(Shape, Basic) { + Shape shape = Shape({1, 2, 3}); + EXPECT_EQ(shape.size(), 3); + EXPECT_EQ(shape[0], 1); + EXPECT_EQ(shape[1], 2); + EXPECT_EQ(shape[2], 3); + + Shape shape2 = Shape(Array({4, 5, 6, 7})); + EXPECT_EQ(shape2.size(), 4); + EXPECT_EQ(shape2[0], 4); + EXPECT_EQ(shape2[1], 5); + EXPECT_EQ(shape2[2], 6); + EXPECT_EQ(shape2[3], 7); + + std::vector vec = {8, 9, 10}; + Shape shape3 = Shape(std::move(vec)); + EXPECT_EQ(shape3.size(), 3); + EXPECT_EQ(shape3[0], 8); + EXPECT_EQ(shape3[1], 9); + EXPECT_EQ(shape3[2], 10); + EXPECT_EQ(shape3.Product(), 8 * 9 * 10); + + Shape shape4 = Shape(); + EXPECT_EQ(shape4.size(), 0); + EXPECT_EQ(shape4.Product(), 1); +} + +TEST(Shape, AnyConvert) { + Shape shape0 = Shape({1, 2, 3}); + Any any0 = shape0; + + auto shape1 = any0.cast(); + EXPECT_EQ(shape1.size(), 3); + EXPECT_EQ(shape1[0], 1); + EXPECT_EQ(shape1[1], 2); + EXPECT_EQ(shape1[2], 3); + + Array arr({1, 2}); + AnyView any_view0 = arr; + auto shape2 = any_view0.cast(); + EXPECT_EQ(shape2.size(), 2); + EXPECT_EQ(shape2[0], 1); + EXPECT_EQ(shape2[1], 2); +} + +} // namespace diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc new file mode 100644 index 000000000000..847ed6f9559c --- /dev/null +++ b/ffi/tests/cpp/test_string.cc @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace { + +using namespace tvm::ffi; + +TEST(String, MoveFromStd) { + using namespace std; + string source = "this is a string"; + string expect = source; + String s(std::move(source)); + string copy = (string)s; + EXPECT_EQ(copy, expect); + EXPECT_EQ(source.size(), 0); +} + +TEST(String, CopyFromStd) { + using namespace std; + string source = "this is a string"; + string expect = source; + String s{source}; + string copy = (string)s; + EXPECT_EQ(copy, expect); + EXPECT_EQ(source.size(), expect.size()); +} + +TEST(String, Assignment) { + using namespace std; + String s{string{"hello"}}; + s = string{"world"}; + EXPECT_EQ(s == "world", true); + string s2{"world2"}; + s = std::move(s2); + EXPECT_EQ(s == "world2", true); + + ObjectRef r; + r = String("hello"); + EXPECT_EQ(r.defined(), true); +} + +TEST(String, empty) { + using namespace std; + String s{"hello"}; + EXPECT_EQ(s.empty(), false); + s = std::string(""); + EXPECT_EQ(s.empty(), true); +} + +TEST(String, Comparisons) { + using namespace std; + string source = "a string"; + string mismatch = "a string but longer"; + String s{"a string"}; + String m{mismatch}; + + EXPECT_EQ("a str" >= s, false); + EXPECT_EQ(s == source, true); + EXPECT_EQ(s == mismatch, false); + EXPECT_EQ(s == source.data(), true); + EXPECT_EQ(s == mismatch.data(), false); + + EXPECT_EQ(s < m, source < mismatch); + EXPECT_EQ(s > m, source > mismatch); + EXPECT_EQ(s <= m, source <= mismatch); + EXPECT_EQ(s >= m, source >= mismatch); + EXPECT_EQ(s == m, source == mismatch); + EXPECT_EQ(s != m, source != mismatch); + + EXPECT_EQ(m < s, mismatch < source); + EXPECT_EQ(m > s, mismatch > source); + EXPECT_EQ(m <= s, mismatch <= source); + EXPECT_EQ(m >= s, mismatch >= source); + EXPECT_EQ(m == s, mismatch == source); + EXPECT_EQ(m != s, mismatch != source); +} + +// Check '\0' handling +TEST(String, null_byte_handling) { + using namespace std; + // Ensure string still compares equal if it contains '\0'. + string v1 = "hello world"; + size_t v1_size = v1.size(); + v1[5] = '\0'; + EXPECT_EQ(v1[5], '\0'); + EXPECT_EQ(v1.size(), v1_size); + String str_v1{v1}; + EXPECT_EQ(str_v1.compare(v1), 0); + EXPECT_EQ(str_v1.size(), v1_size); + + // Ensure bytes after '\0' are taken into account for mismatches. + string v2 = "aaa one"; + string v3 = "aaa two"; + v2[3] = '\0'; + v3[3] = '\0'; + String str_v2{v2}; + String str_v3{v3}; + EXPECT_EQ(str_v2.compare(str_v3), -1); + EXPECT_EQ(str_v2.size(), 7); + // strcmp won't be able to detect the mismatch + EXPECT_EQ(strcmp(v2.data(), v3.data()), 0); + // string::compare can handle \0 since it knows size + EXPECT_LT(v2.compare(v3), 0); + + // If there is mismatch before '\0', should still handle it. + string v4 = "acc one"; + string v5 = "abb two"; + v4[3] = '\0'; + v5[3] = '\0'; + String str_v4{v4}; + String str_v5{v5}; + EXPECT_GT(str_v4.compare(str_v5), 0); + EXPECT_EQ(str_v4.size(), 7); + // strcmp is able to detect the mismatch + EXPECT_GT(strcmp(v4.data(), v5.data()), 0); + // string::compare can handle \0 since it knows size + EXPECT_GT(v4.compare(v5), 0); +} + +TEST(String, compare_same_memory_region_different_size) { + using namespace std; + string source = "a string"; + String str_source{source}; + char* memory = const_cast(str_source.data()); + EXPECT_EQ(str_source.compare(memory), 0); + // This changes the string size + memory[2] = '\0'; + // memory is logically shorter now + EXPECT_GT(str_source.compare(memory), 0); +} + +TEST(String, compare) { + using namespace std; + constexpr auto mismatch1_cstr = "a string but longer"; + string source = "a string"; + string mismatch1 = mismatch1_cstr; + string mismatch2 = "a strin"; + string mismatch3 = "a b"; + string mismatch4 = "a t"; + String str_source{source}; + String str_mismatch1{mismatch1_cstr}; + String str_mismatch2{mismatch2}; + String str_mismatch3{mismatch3}; + String str_mismatch4{mismatch4}; + + // compare with string + EXPECT_EQ(str_source.compare(source), 0); + EXPECT_TRUE(str_source == source); + EXPECT_TRUE(source == str_source); + EXPECT_TRUE(str_source <= source); + EXPECT_TRUE(source <= str_source); + EXPECT_TRUE(str_source >= source); + EXPECT_TRUE(source >= str_source); + EXPECT_LT(str_source.compare(mismatch1), 0); + EXPECT_TRUE(str_source < mismatch1); + EXPECT_TRUE(mismatch1 != str_source); + EXPECT_GT(str_source.compare(mismatch2), 0); + EXPECT_TRUE(str_source > mismatch2); + EXPECT_TRUE(mismatch2 < str_source); + EXPECT_GT(str_source.compare(mismatch3), 0); + EXPECT_TRUE(str_source > mismatch3); + EXPECT_LT(str_source.compare(mismatch4), 0); + EXPECT_TRUE(str_source < mismatch4); + EXPECT_TRUE(mismatch4 > str_source); + + // compare with char* + EXPECT_EQ(str_source.compare(source.data()), 0); + EXPECT_TRUE(str_source == source.data()); + EXPECT_TRUE(source.data() == str_source); + EXPECT_TRUE(str_source <= source.data()); + EXPECT_TRUE(source <= str_source.data()); + EXPECT_TRUE(str_source >= source.data()); + EXPECT_TRUE(source >= str_source.data()); + EXPECT_LT(str_source.compare(mismatch1.data()), 0); + EXPECT_TRUE(str_source < mismatch1.data()); + EXPECT_TRUE(str_source != mismatch1.data()); + EXPECT_TRUE(mismatch1.data() != str_source); + EXPECT_GT(str_source.compare(mismatch2.data()), 0); + EXPECT_TRUE(str_source > mismatch2.data()); + EXPECT_TRUE(mismatch2.data() < str_source); + EXPECT_GT(str_source.compare(mismatch3.data()), 0); + EXPECT_TRUE(str_source > mismatch3.data()); + EXPECT_LT(str_source.compare(mismatch4.data()), 0); + EXPECT_TRUE(str_source < mismatch4.data()); + EXPECT_TRUE(mismatch4.data() > str_source); + + // compare with String + EXPECT_LT(str_source.compare(str_mismatch1), 0); + EXPECT_TRUE(str_source < str_mismatch1); + EXPECT_GT(str_source.compare(str_mismatch2), 0); + EXPECT_TRUE(str_source > str_mismatch2); + EXPECT_GT(str_source.compare(str_mismatch3), 0); + EXPECT_TRUE(str_source > str_mismatch3); + EXPECT_LT(str_source.compare(str_mismatch4), 0); + EXPECT_TRUE(str_source < str_mismatch4); +} + +TEST(String, c_str) { + using namespace std; + string source = "this is a string"; + string mismatch = "mismatch"; + String s{source}; + + EXPECT_EQ(std::strcmp(s.c_str(), source.data()), 0); + EXPECT_NE(std::strcmp(s.c_str(), mismatch.data()), 0); +} + +TEST(String, hash) { + using namespace std; + string source = "this is a string"; + String s{source}; + std::hash()(s); + + std::unordered_map map; + String k1{string{"k1"}}; + string v1{"v1"}; + String k2{string{"k2"}}; + string v2{"v2"}; + map[k1] = v1; + map[k2] = v2; + + EXPECT_EQ(map[k1], v1); + EXPECT_EQ(map[k2], v2); +} + +TEST(String, Cast) { + using namespace std; + string source = "this is a string"; + String s{source}; + ObjectRef r = s; + String s2 = Downcast(r); +} + +TEST(String, Concat) { + String s1("hello"); + String s2("world"); + std::string s3("world"); + String res1 = s1 + s2; + String res2 = s1 + s3; + String res3 = s3 + s1; + String res4 = s1 + "world"; + String res5 = "world" + s1; + + EXPECT_EQ(res1.compare("helloworld"), 0); + EXPECT_EQ(res2.compare("helloworld"), 0); + EXPECT_EQ(res3.compare("worldhello"), 0); + EXPECT_EQ(res4.compare("helloworld"), 0); + EXPECT_EQ(res5.compare("worldhello"), 0); +} + +TEST(String, Any) { + // test anyview promotion to any + AnyView view = "hello"; + + Any b = view; + EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(b.as().value(), "hello"); + EXPECT_EQ(b.as().value(), "hello"); + + std::string s_world = "world"; + view = s_world; + EXPECT_EQ(view.as().value(), "world"); + + String s{"hello"}; + Any a = s; + EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(a.as().value(), "hello"); + EXPECT_EQ(a.as().value(), "hello"); + + Any c = "helloworld"; + EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(c.as().value(), "helloworld"); + EXPECT_EQ(c.as().value(), "helloworld"); +} + +TEST(String, Bytes) { + // explicitly test zero element + std::string s = {'\0', 'a', 'b', 'c'}; + Bytes b = s; + EXPECT_EQ(b.size(), 4); + EXPECT_EQ(b.operator std::string(), s); + + TVMFFIByteArray arr{s.data(), static_cast(s.size())}; + Bytes b2 = arr; + EXPECT_EQ(b2.size(), 4); + EXPECT_EQ(b2.operator std::string(), s); +} + +TEST(String, BytesAny) { + std::string s = {'\0', 'a', 'b', 'c'}; + TVMFFIByteArray arr{s.data(), static_cast(s.size())}; + + AnyView view = &arr; + EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIByteArrayPtr); + EXPECT_EQ(view.as().value().operator std::string(), s); + + Any b = view; + EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIBytes); + + EXPECT_EQ(b.as().value().operator std::string(), s); + EXPECT_EQ(b.as().value(), s); +} + +TEST(String, StdString) { + std::string s1 = "test_string"; + AnyView view1 = s1; + EXPECT_EQ(view1.type_index(), TypeIndex::kTVMFFIRawStr); + EXPECT_EQ(view1.as().value(), s1); + + TVMFFIByteArray arr1{s1.data(), static_cast(s1.size())}; + AnyView view2 = &arr1; + EXPECT_EQ(view2.type_index(), TypeIndex::kTVMFFIByteArrayPtr); + EXPECT_EQ(view2.as().value(), s1); + + Bytes bytes1 = s1; + AnyView view3 = bytes1; + EXPECT_EQ(view3.type_index(), TypeIndex::kTVMFFIBytes); + EXPECT_EQ(view3.as().value(), s1); + + String string1 = s1; + AnyView view4 = string1; + EXPECT_EQ(view4.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(view4.as().value(), s1); + + // Test with Any + Any any1 = s1; + EXPECT_EQ(any1.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(any1.as().value(), s1); + + Any any2 = &arr1; + EXPECT_EQ(any2.type_index(), TypeIndex::kTVMFFIBytes); + EXPECT_EQ(any2.as().value(), s1); + + Any any3 = bytes1; + EXPECT_EQ(any3.type_index(), TypeIndex::kTVMFFIBytes); + EXPECT_EQ(any3.as().value(), s1); + + Any any4 = string1; + EXPECT_EQ(any4.type_index(), TypeIndex::kTVMFFIStr); + EXPECT_EQ(any4.as().value(), s1); +} + +TEST(String, CAPIAccessor) { + using namespace std; + String s{"hello"}; + TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(s); + TVMFFIByteArray* arr = TVMFFIBytesGetByteArrayPtr(obj); + EXPECT_EQ(arr->size, 5); + EXPECT_EQ(std::string(arr->data, arr->size), "hello"); +} +} // namespace diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc new file mode 100644 index 000000000000..42c8e6aacc2b --- /dev/null +++ b/ffi/tests/cpp/test_tuple.cc @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Tuple, Basic) { + Tuple tuple0(1, 2.0f); + EXPECT_EQ(tuple0.get<0>(), 1); + EXPECT_EQ(tuple0.get<1>(), 2.0f); + + Tuple tuple1 = tuple0; + EXPECT_EQ(tuple0.use_count(), 2); + + // test copy on write + tuple1.Set<0>(3); + EXPECT_EQ(tuple0.get<0>(), 1); + EXPECT_EQ(tuple1.get<0>(), 3); + + EXPECT_EQ(tuple0.use_count(), 1); + EXPECT_EQ(tuple1.use_count(), 1); + + // copy on write not triggered because + // tuple1 is unique. + tuple1.Set<1>(4); + EXPECT_EQ(tuple1.get<1>(), 4.0f); + EXPECT_EQ(tuple1.use_count(), 1); + + // default state + Tuple tuple2; + EXPECT_EQ(tuple2.use_count(), 1); + tuple2.Set<0>(1); + tuple2.Set<1>(2.0f); + EXPECT_EQ(tuple2.get<0>(), 1); + EXPECT_EQ(tuple2.get<1>(), 2.0f); + + // tuple of object and primitive + Tuple tuple3(1, 2); + EXPECT_EQ(tuple3.get<0>()->value, 1); + EXPECT_EQ(tuple3.get<1>(), 2); + tuple3.Set<0>(4); + EXPECT_EQ(tuple3.get<0>()->value, 4); +} + +TEST(Tuple, AnyConvert) { + Tuple tuple0(1, 2); + AnyView view0 = tuple0; + Array arr0 = view0.as>().value(); + EXPECT_EQ(arr0.size(), 2); + EXPECT_EQ(arr0[0].as().value(), 1); + EXPECT_EQ(arr0[1].as().value()->value, 2); + + // directly reuse the underlying storage. + auto tuple1 = view0.cast>(); + EXPECT_TRUE(tuple0.same_as(tuple1)); + + Any any0 = view0; + // trigger a copy due to implict conversion + auto tuple2 = any0.cast>(); + EXPECT_TRUE(!tuple0.same_as(tuple2)); + EXPECT_EQ(tuple2.get<0>()->value, 1); + EXPECT_EQ(tuple2.get<1>()->value, 2); +} + +TEST(Tuple, FromUnpacked) { + // try decution + Function fadd1 = Function::FromUnpacked([](const Tuple& a) -> int { + return a.get<0>() + static_cast(a.get<1>()->value); + }); + int b = fadd1(Tuple(1, 2)).cast(); + EXPECT_EQ(b, 3); + + int c = fadd1(Array({1, 2})).cast(); + EXPECT_EQ(c, 3); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(Array({1.1, 2})); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched type on argument #0 when calling: `(0: Tuple) -> int`. " + "Expected `Tuple` but got `Array[index 0: float]`"); + throw; + } + }, + ::tvm::ffi::Error); + + EXPECT_THROW( + { + try { + fadd1(Array({1.1})); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched type on argument #0 when calling: `(0: Tuple) -> int`. " + "Expected `Tuple` but got `Array[size=1]`"); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Tuple, Upcast) { + Tuple t0(1, 2.0f); + Tuple t1 = t0; + EXPECT_EQ(t1.get<0>().cast(), 1); + EXPECT_EQ(t1.get<1>().cast(), 2.0f); + static_assert(details::type_contains_v, Tuple>); + static_assert(details::type_contains_v, Tuple>); + static_assert(details::type_contains_v, Tuple>); +} +} // namespace diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc new file mode 100644 index 000000000000..94cbcd491a6e --- /dev/null +++ b/ffi/tests/cpp/test_variant.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Variant, Basic) { + Variant v1 = 1; + EXPECT_EQ(v1.get(), 1); + EXPECT_EQ(v1.as().value(), 1.0f); + + Variant v2 = 2.0f; + EXPECT_EQ(v2.get(), 2.0f); + v2 = v1; + EXPECT_EQ(v2.get(), 1); +} + +TEST(Variant, AnyConvert) { + Variant v = 1; + AnyView view0 = v; + EXPECT_EQ(view0.as().value(), 1); + + // implicit convert to variant + Any any0 = 1; + auto v1 = any0.cast>>(); + EXPECT_EQ(v1.get()->value, 1); + + // move from any to variant + Variant v2 = TInt(1); + Any any1 = std::move(v2); + auto v3 = std::move(any1).cast>(); + auto v4 = std::move(v3).get(); + EXPECT_EQ(v4->value, 1); + EXPECT_EQ(v4.use_count(), 1); +} + +TEST(Variant, ObjectPtrHashEqual) { + TInt x = TInt(1); + TFloat y = TFloat(1.0f); + + Variant v0 = x; + Variant v1 = y; + Variant v2 = v1; + + EXPECT_EQ(ObjectPtrHash()(v0), ObjectPtrHash()(x)); + EXPECT_TRUE(!ObjectPtrEqual()(v0, v1)); + EXPECT_TRUE(!ObjectPtrEqual()(v0, v2)); +} + +TEST(Variant, FromUnpacked) { + // try decution + Function fadd1 = Function::FromUnpacked([](const Variant& a) -> int64_t { + if (auto opt_int = a.as()) { + return opt_int.value() + 1; + } else { + return a.get()->value + 1; + } + }); + int b = fadd1(1).cast(); + EXPECT_EQ(b, 2); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd1(1.1); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ( + error.message(), + "Mismatched type on argument #0 when calling: `(0: Variant) -> int`. " + "Expected `Variant` but got `float`"); + throw; + } + }, + ::tvm::ffi::Error); + + Function fadd2 = Function::FromUnpacked([](const Array>& a) -> int64_t { + if (auto opt_int = a[0].as()) { + return opt_int.value() + 1; + } else { + return a[0].get()->value + 1; + } + }); + int c = fadd2(Array({1, 2})).cast(); + EXPECT_EQ(c, 2); + + // convert that triggers error + EXPECT_THROW( + { + try { + fadd2(Array({1, 1.1})); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "TypeError"); + EXPECT_EQ(error.message(), + "Mismatched type on argument #0 when calling: `(0: Array>) -> int`. " + "Expected `Array>` but got `Array[index 1: float]`"); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Variant, Upcast) { + Array a0 = {1, 2, 3}; + static_assert(details::type_contains_v>, Array>); + Array> a1 = a0; + EXPECT_EQ(a1[0].get(), 1); +} + +} // namespace diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h new file mode 100644 index 000000000000..69a91efc46d0 --- /dev/null +++ b/ffi/tests/cpp/testing_object.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_FFI_TESTING_OBJECT_H_ +#define TVM_FFI_TESTING_OBJECT_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace testing { + +// We deliberately pad extra +// in the header to test cases +// where the object subclass address +// do not align with the base object address +// not handling properly will cause buffer overflow +class BasePad { + public: + int64_t extra[4]; +}; + +class TNumberObj : public BasePad, public Object { + public: + // declare as one slot, with float as overflow + static constexpr uint32_t _type_child_slots = 1; + static constexpr const char* _type_key = "test.Number"; + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TNumberObj, Object); +}; + +class TNumber : public ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS(TNumber, ObjectRef, TNumberObj); +}; + +class TIntObj : public TNumberObj { + public: + int64_t value; + + TIntObj(int64_t value) : value(value) {} + + int64_t GetValue() const { return value; } + + static constexpr const char* _type_key = "test.Int"; + + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); +}; + +TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value); + +class TInt : public TNumber { + public: + explicit TInt(int64_t value) { data_ = make_object(value); } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj); +}; + +class TFloatObj : public TNumberObj { + public: + double value; + + TFloatObj(double value) : value(value) {} + + static constexpr const char* _type_key = "test.Float"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj); +}; + +TVM_FFI_REFLECTION_DEF(TFloatObj).def_readonly("value", &TFloatObj::value); + +class TFloat : public TNumber { + public: + explicit TFloat(double value) { data_ = make_object(value); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFloat, TNumber, TFloatObj); +}; + +// TPrimExpr is used for testing FallbackTraits +class TPrimExprObj : public Object { + public: + std::string dtype; + double value; + + TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {} + + static constexpr const char* _type_key = "test.PrimExpr"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TPrimExprObj, Object); +}; + +class TPrimExpr : public ObjectRef { + public: + explicit TPrimExpr(std::string dtype, double value) { + data_ = make_object(dtype, value); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(TPrimExpr, ObjectRef, TPrimExprObj); +}; +} // namespace testing + +template <> +inline constexpr bool use_default_type_traits_v = true; + +template <> +struct TypeTraits + : public ObjectRefWithFallbackTraitsBase { + static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(StrictBool value) { + return testing::TPrimExpr("bool", static_cast(value)); + } + + static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(int64_t value) { + return testing::TPrimExpr("int64", static_cast(value)); + } + + static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(double value) { + return testing::TPrimExpr("float32", static_cast(value)); + } + // hack into the dtype to store string + static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(String value) { + return testing::TPrimExpr(value, 0); + } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_TESTING_OBJECT_H_ diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index d038d5f59a5f..eef32b4773f0 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -48,6 +48,7 @@ #include #include #include +#include #include #include @@ -63,10 +64,10 @@ namespace tvm { * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ -#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ - static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ - template \ +#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ + static constexpr const char* _type_key = TypeKey; \ + TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode); \ + template \ void _tvm_VisitAttrs(FVisit& _tvm_fvisit) // NOLINT(*) /*! @@ -97,7 +98,7 @@ struct AttrError : public Error { * \brief constructor * \param msg error message */ - explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {} + explicit AttrError(std::string msg) : Error("AttributeError", msg, TVM_FFI_TRACEBACK_HERE) {} }; /*! @@ -201,7 +202,7 @@ class Attrs : public ObjectRef { class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ - Map dict; + Map dict; bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { return equal(dict, other->dict); @@ -230,7 +231,7 @@ class DictAttrs : public Attrs { * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. */ - TVM_DLL explicit DictAttrs(Map dict = {}); + TVM_DLL explicit DictAttrs(Map dict = {}); // Utils for accessing attributes // This needs to be on DictAttrs, not DictAttrsNode because we return the default @@ -257,24 +258,12 @@ class DictAttrs : public Attrs { template Optional GetAttr( const std::string& attr_key, - Optional default_value = Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); + Optional default_value = Optional(std::nullopt)) const { if (!defined()) return default_value; const DictAttrsNode* node = this->as(); - auto it = node->dict.find(attr_key); if (it != node->dict.end()) { - // For backwards compatibility, return through TVMRetValue. - // This triggers any automatic conversions registered with - // PackedFuncValueConverter. Importantly, this allows use of - // `GetAttr` and `GetAttr` for properties that - // are stored internally as `runtime::Box` and - // `runtime::Box`. - TVMRetValue ret; - ret = (*it).second; - Optional obj = ret; - return obj; + return (*it).second.cast(); } else { return default_value; } @@ -320,7 +309,7 @@ template inline TAttrs AttrsWithDefaultValues() { static_assert(std::is_base_of::value, "Can only take attr nodes"); auto n = make_object(); - n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false); + n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false); return TAttrs(n); } @@ -334,7 +323,7 @@ inline TAttrs AttrsWithDefaultValues() { * * \returns The new DictAttrs with updated attributes. */ -DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); /*! * \brief Copy the DictAttrs, but overrides a single attribute. @@ -347,9 +336,9 @@ DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); * * \returns The new DictAttrs with updated attributes. */ -DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value); +DictAttrs WithAttr(DictAttrs attrs, String key, Any value); -inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) { +inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) { return WithAttr(std::move(attrs), String(key), std::move(value)); } @@ -392,7 +381,7 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key); * \endcode */ template -inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { +inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); @@ -412,7 +401,7 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v * \returns The new function or module with updated attributes. */ template -inline TFunc WithAttrs(TFunc input, Map attrs) { +inline TFunc WithAttrs(TFunc input, Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); @@ -578,7 +567,7 @@ struct AttrInitEntry { const T& val = *value_; if (begin > val) { std::ostringstream os; - os << type_key_ << "." << key_ << ": " + os << type_key_ << "." << key_ << "'s " << "value " << val << " is smaller than the lower bound " << begin; throw AttrError(os.str()); } @@ -590,7 +579,7 @@ struct AttrInitEntry { const T& val = *value_; if (val > end) { std::ostringstream os; - os << type_key_ << "." << key_ << ": " + os << type_key_ << "." << key_ << "'s " << "value " << val << " is bigger than the upper bound " << end; throw AttrError(os.str()); } @@ -609,41 +598,41 @@ struct AttrInitEntry { // Template function to allow smart conversion // from Expr types into the constants. template -inline void SetValue(T* ptr, const TVMArgValue& val) { - *ptr = val.operator T(); +inline void SetValue(T* ptr, const ffi::AnyView& val) { + *ptr = val.cast(); } template -inline void SetIntValue(T* ptr, const TVMArgValue& val) { - if (val.type_code() == kDLInt) { - *ptr = static_cast(val.value().v_int64); +inline void SetIntValue(T* ptr, const ffi::AnyView& val) { + if (auto opt_int = val.as()) { + *ptr = static_cast(opt_int.value()); } else { - IntImm expr = val; + IntImm expr = val.cast(); *ptr = static_cast(expr->value); } } // Workaround for GCC8.1 / GCC8.2 template <> -inline void SetValue(DataType* ptr, const TVMArgValue& val) { - *ptr = val.operator DataType(); +inline void SetValue(DataType* ptr, const ffi::AnyView& val) { + *ptr = DataType(val.cast()); } template <> -inline void SetValue(std::string* ptr, const TVMArgValue& val) { - if (String::CanConvertFrom(val)) { - *ptr = val.operator std::string(); +inline void SetValue(std::string* ptr, const ffi::AnyView& val) { + if (auto opt_str = val.as()) { + *ptr = opt_str.value(); } else { LOG(FATAL) << "Expect str"; } } template <> -inline void SetValue(double* ptr, const TVMArgValue& val) { - if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { - *ptr = val.operator double(); +inline void SetValue(double* ptr, const ffi::AnyView& val) { + if (auto opt_double = val.as()) { + *ptr = opt_double.value(); } else { - ObjectRef expr = val; + ObjectRef expr = val.cast(); ICHECK(expr.defined()); if (const IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); @@ -655,19 +644,19 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } } template <> -inline void SetValue(int* ptr, const TVMArgValue& val) { +inline void SetValue(int* ptr, const ffi::AnyView& val) { SetIntValue(ptr, val); } template <> -inline void SetValue(int64_t* ptr, const TVMArgValue& val) { +inline void SetValue(int64_t* ptr, const ffi::AnyView& val) { SetIntValue(ptr, val); } template <> -inline void SetValue(uint64_t* ptr, const TVMArgValue& val) { +inline void SetValue(uint64_t* ptr, const ffi::AnyView& val) { SetIntValue(ptr, val); } template <> -inline void SetValue(bool* ptr, const TVMArgValue& val) { +inline void SetValue(bool* ptr, const ffi::AnyView& val) { SetIntValue(ptr, val); } @@ -683,7 +672,7 @@ class AttrInitVisitor { template AttrInitEntry operator()(const char* key, T* value) { - TVMArgValue val; + ffi::AnyView val; AttrInitEntry opt; opt.type_key_ = type_key_; opt.key_ = key; @@ -729,12 +718,22 @@ struct TypeName { template <> struct TypeName { - static constexpr const char* value = "int64"; + static constexpr const char* value = "int"; +}; + +template <> +struct TypeName> { + static constexpr const char* value = "Optional[int]"; +}; + +template <> +struct TypeName> { + static constexpr const char* value = "Optional[float]"; }; template <> struct TypeName { - static constexpr const char* value = "uint64_t"; + static constexpr const char* value = "int"; }; template <> @@ -759,7 +758,7 @@ struct TypeName { template <> struct TypeName { - static constexpr const char* value = "double"; + static constexpr const char* value = "float"; }; class AttrDocEntry { @@ -886,10 +885,9 @@ class AttrsNode : public BaseAttrsNode { // applies two strategies to lookup if (args.size() < kLinearSearchBound) { // linear search. - auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { + auto ffind = [&args](const char* key, ffi::AnyView* val) { for (int i = 0; i < args.size(); i += 2) { - ICHECK_EQ(args.type_codes[i], kTVMStr); - if (!std::strcmp(key, args.values[i].v_str)) { + if (!std::strcmp(key, args[i].cast())) { *val = args[i + 1]; return true; } @@ -903,10 +901,9 @@ class AttrsNode : public BaseAttrsNode { // construct a map then do lookup. std::unordered_map kwargs; for (int i = 0; i < args.size(); i += 2) { - ICHECK_EQ(args.type_codes[i], kTVMStr); - kwargs[args[i].operator std::string()] = args[i + 1]; + kwargs[args[i].cast()] = args[i + 1]; } - auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) { + auto ffind = [&kwargs](const char* key, ffi::AnyView* val) { auto it = kwargs.find(key); if (it != kwargs.end()) { *val = it->second; @@ -922,11 +919,11 @@ class AttrsNode : public BaseAttrsNode { if (hit_count * 2 != args.size() && !allow_unknown) { for (int i = 0; i < args.size(); i += 2) { ::tvm::detail::AttrExistVisitor visitor; - visitor.key_ = args[i].operator std::string(); + visitor.key_ = args[i].cast(); self()->_tvm_VisitAttrs(visitor); if (!visitor.exist_) { std::ostringstream os; - os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ + os << DerivedType::_type_key << " does not have field \'" << visitor.key_ << "\', Possible fields:\n"; os << "----------------\n"; this->PrintDocString(os); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index dd93af6852fe..c44e102ccadd 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -139,8 +139,16 @@ class TypedEnvFunc : public ObjectRef { R operator()(Args... args) const { const EnvFuncNode* n = operator->(); ICHECK(n != nullptr); - return runtime::detail::typed_packed_call_dispatcher::run(n->func, - std::forward(args)...); + if constexpr (std::is_same_v) { + n->func(std::forward(args)...); + } else { + Any res = n->func(std::forward(args)...); + if constexpr (std::is_same_v) { + return res; + } else { + return std::move(res).cast(); + } + } } /*! \brief specify container node */ using ContainerType = EnvFuncNode; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 53af26975648..a3defa592af6 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -56,6 +56,7 @@ class BaseExprNode : public Object { mutable Span span; static constexpr const char* _type_key = "BaseExpr"; + static constexpr const bool _type_has_method_visit_attrs = true; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 64; @@ -129,13 +130,51 @@ class PrimExpr : public BaseExpr { DataType dtype() const { return static_cast(get())->dtype; } TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); +}; + +/*! + * \brief Base class for other IR constructs that can be converted to PrimExpr. + * This is useful for the FFI to convert the expressions to PrimExpr. + * \sa PrimExpr + */ +class PrimExprConvertibleNode : public Object { + public: + virtual ~PrimExprConvertibleNode() {} + virtual PrimExpr ToPrimExpr() const = 0; - private: - // Internal function for conversion. - friend struct runtime::PackedFuncValueConverter; - TVM_DLL static PrimExpr FromObject_(ObjectRef ref); + static constexpr const char* _type_key = "PrimExprConvertible"; + TVM_DECLARE_BASE_OBJECT_INFO(PrimExprConvertibleNode, Object); }; +/*! + * \brief Managed reference to PrimExprConvertibleNode. + * \sa PrimExprConvertibleNode + */ +class PrimExprConvertible : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PrimExprConvertible, ObjectRef, PrimExprConvertibleNode); +}; + +namespace ffi { +// define automatic conversion from bool, int64_t, double, String to PrimExpr +// These functions are declared early to avoid circular dependency +template <> +inline constexpr bool use_default_type_traits_v = false; + +template <> +struct TypeTraits + : public ObjectRefWithFallbackTraitsBase { + static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(StrictBool value); + static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(int64_t value); + static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(double value); + static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(String value); + static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(PrimExprConvertible value) { + return value->ToPrimExpr(); + } +}; +} // namespace ffi + /*! * \brief add operator * @@ -622,7 +661,7 @@ class Integer : public IntImm { * \param other another expression. */ Integer& operator=(const IntImm& other) { - data_ = ObjectRef::GetDataPtr(other); + data_ = ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(other); return *this; } /*! @@ -728,129 +767,63 @@ inline const TTypeNode* RelaxExprNode::type_as() const { return node; } -} // namespace tvm - -namespace tvm { -namespace runtime { - -// Automatic conversion into IntImm, Integer, and Bool, when called -// through the FFI. Automatic conversions into PrimExpr are -// registered in "tvm/tir/expr.h", as it includes conversions to the -// TIR-only StringImm. -// -// While the FFI only requires the From() method, these -// implementations also define a TryFrom() method to avoid duplicate -// logic in the PrimExpr conversion. - +namespace ffi { +// Type traits to enable automatic conversion into IntImm, Integer, and Bool +// when called through the FFI template <> -struct PackedFuncValueConverter { - template - static Optional TryFrom(const PODSubclass& val) { - if (auto opt = val.TryAsInt()) { - int64_t value = opt.value(); - auto dtype = - (value > std::numeric_limits::max() || value < std::numeric_limits::min()) - ? DataType::Int(64) - : DataType::Int(32); - return IntImm(dtype, value); - } else if (auto opt = val.TryAsBool()) { - return IntImm(DataType::Int(32), opt.value()); - } else { - return NullOpt; - } - } - - template - static tvm::IntImm From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); - } - } -}; +inline constexpr bool use_default_type_traits_v = false; +// specialize to enable implicit conversion from const char* template <> -struct PackedFuncValueConverter { - template - static tvm::Integer From(const PODSubclass& val) { - if (auto opt = PackedFuncValueConverter::TryFrom(val)) { - return Integer(opt.value()); - } else { - return val.template AsObjectRef(); - } +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + static TVM_FFI_INLINE IntImm ConvertFallbackValue(int64_t value) { + auto dtype = + (value > std::numeric_limits::max() || value < std::numeric_limits::min()) + ? DataType::Int(64) + : DataType::Int(32); + return IntImm(dtype, value); } }; template <> -struct PackedFuncValueConverter { - template - static Optional TryFrom(const PODSubclass& val) { - if (auto opt = val.TryAsBool()) { - return tvm::Bool(opt.value()); - } else if (auto opt = val.TryAsInt()) { - int value = opt.value(); - ICHECK(value == 0 || value == 1) - << "ValueError: boolean value can only be 0 or 1, but get " << value; - return tvm::Bool(static_cast(value)); - } else { - return NullOpt; - } - } +inline constexpr bool use_default_type_traits_v = false; - template - static tvm::Bool From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); - } - } +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + static TVM_FFI_INLINE Integer ConvertFallbackValue(int64_t value) { return Integer(value); } }; template <> -struct PackedFuncValueConverter { - static Optional TryFrom(const TVMPODValue_& val) { - if (auto opt = val.TryAsFloat()) { - return FloatImm(runtime::DataType::Float(32), opt.value()); - } else { - return NullOpt; - } - } +inline constexpr bool use_default_type_traits_v = false; - template - static tvm::FloatImm From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); - } +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + static TVM_FFI_INLINE FloatImm ConvertFallbackValue(double value) { + return FloatImm(runtime::DataType::Float(32), value); } }; -/* \brief Backwards compatibility wrapper for IntImm arguments - * - * In previous versions of TVM, IntImm was the default FFI type for - * integer arguments, instead of runtime::Int. For backwards - * compatibility where the callee has been updated to expected a - * runtime::Int, the caller has not been updated to provide a - * runtime::Int, and the auto-unboxing of - * runtime::Int does not apply (e.g. making an `Array`), - * allow the IntImm to be generated. - */ template <> -struct PackedFuncValueConverter { - template - static runtime::Int From(const PODSubclass& val) { - if (val.template IsObjectRef()) { - return runtime::Int(val.template AsObjectRef()->value); - } else { - return val.template AsObjectRef(); - } - } +inline constexpr bool use_default_type_traits_v = false; + +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + static TVM_FFI_INLINE Bool ConvertFallbackValue(int64_t value) { return Bool(value != 0); } }; -} // namespace runtime +// define automatic conversion from bool, int64_t, double to PrimExpr +TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(StrictBool value) { + return IntImm(DataType::Bool(), value, Span()); +} + +TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(int64_t value) { + return TypeTraits::ConvertFallbackValue(value); +} + +TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(double value) { + return TypeTraits::ConvertFallbackValue(value); +} +} // namespace ffi } // namespace tvm /* \brief Allow tvm.GLobalVar as key in STL tables diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 2282cb979b5e..cee94d37a5c0 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -161,9 +161,8 @@ class BaseFuncNode : public RelaxExprNode { * \endcode */ template - Optional GetAttr( - const std::string& attr_key, - Optional default_value = Optional(nullptr)) const { + Optional GetAttr(const std::string& attr_key, + Optional default_value = std::nullopt) const { return attrs.GetAttr(attr_key, default_value); } // variant that uses TObjectRef to enable implicit conversion to default value. diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 0338096d7047..66637f67d948 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -91,7 +91,7 @@ class IRModuleNode : public Object { template Optional GetAttr( const std::string& attr_key, - Optional default_value = Optional(nullptr)) const { + Optional default_value = Optional(std::nullopt)) const { return attrs.GetAttr(attr_key, default_value); } // variant that uses TObjectRef to enable implicit conversion to default value. @@ -290,9 +290,6 @@ class IRModule : public ObjectRef { /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; - /*! \brief Declare whether Ref is nullable. */ - static constexpr bool _type_is_nullable = false; - // allow copy on write. TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode); }; @@ -305,8 +302,6 @@ namespace attr { * \brief Name of the module * * Type: String - * - * \sa tvm::runtime::String */ constexpr const char* kModuleName = "mod_name"; diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 11dac3fe52ad..df0fba16cb6c 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -25,6 +25,7 @@ #define TVM_IR_NAME_SUPPLY_H_ #include +#include #include #include #include diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 7fbd1cbb84f1..a37485700544 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -160,7 +160,7 @@ class Op : public RelaxExpr { */ TVM_DLL static const Op& Get(const String& op_name); - TVM_DEFINE_OBJECT_REF_METHODS(Op, RelaxExpr, OpNode) + TVM_DEFINE_OBJECT_REF_METHODS(Op, RelaxExpr, OpNode); private: /*! @@ -359,7 +359,7 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*) inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*) get()->attrs_type_key = key; - get()->attrs_type_index = Object::TypeKey2Index(key); + get()->attrs_type_index = tvm::ffi::TypeKeyToIndex(key.c_str()); return *this; } @@ -372,9 +372,7 @@ template inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) const std::string& attr_name, const ValueType& value, int plevel) { ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; - runtime::TVMRetValue rv; - rv = value; - UpdateAttr(attr_name, rv, plevel); + UpdateAttr(attr_name, Any(value), plevel); return *this; } diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 9b3041f3c000..7b79a2c89455 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -214,7 +214,7 @@ class SourceMap; /*! * \brief Stores locations in frontend source that generated a node. */ -class SourceMapNode : public Object { +class SourceMapObj : public Object { public: /*! \brief The source mapping. */ Map source_map; @@ -222,12 +222,12 @@ class SourceMapNode : public Object { // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } - bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { + bool SEqualReduce(const SourceMapObj* other, SEqualReducer equal) const { return equal(source_map, other->source_map); } static constexpr const char* _type_key = "SourceMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object); }; class SourceMap : public ObjectRef { @@ -241,12 +241,12 @@ class SourceMap : public ObjectRef { void Add(const Source& source); - SourceMapNode* operator->() { + SourceMapObj* operator->() { ICHECK(get() != nullptr); - return static_cast(get_mutable()); + return static_cast(get_mutable()); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapObj); }; } // namespace tvm diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 503b2c8d141a..0da882f3884d 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -86,7 +86,7 @@ class PassContextNode : public Object { /*! \brief The diagnostic context. */ mutable Optional diag_ctx; /*! \brief Pass specific configurations. */ - Map config; + Map config; /*! \brief A list of pass instrument implementations. */ Array instruments; @@ -114,10 +114,9 @@ class PassContextNode : public Object { * \throw Error if the key exists but the value does not match TObjectRef. */ template - Optional GetConfig(const std::string& key, Optional default_value = - Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); + Optional GetConfig( + const std::string& key, + Optional default_value = Optional(std::nullopt)) const { if (!config.defined()) return default_value; auto it = config.find(key); if (it != config.end()) { @@ -267,36 +266,23 @@ class PassContext : public ObjectRef { * \tparam ValueType The value type to be registered */ template - static uint32_t RegisterConfigOption(const char* key) { - using ValueNodeType = typename ValueType::ContainerType; + static int32_t RegisterConfigOption(const char* key) { // NOTE: we could further update the function later. - uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); - auto type_key = runtime::Object::TypeIndex2Key(tindex); - + int32_t tindex = ffi::TypeToRuntimeTypeIndex::v(); auto* reflection = ReflectionVTable::Global(); + auto type_key = ffi::TypeIndexToTypeKey(tindex); - auto legalization = [=](ObjectRef obj) -> ObjectRef { - if (obj->IsInstance::ContainerType>()) { - return reflection->CreateObject(type_key, Downcast>(obj)); + auto legalization = [=](ffi::Any value) -> ffi::Any { + if (auto opt_map = value.as>()) { + return reflection->CreateObject(type_key, opt_map.value()); } else { - // Backwards compatibility for config options defined prior to - // https://github.com/apache/tvm/pull/16183. This commit - // changed the default FFI conversion of python integers from - // `tvm::IntImm` to `runtime::Int`. - // - // This backwards compatibility fix can be removed when all - // options registered with TVM_REGISTER_PASS_CONFIG_OPTION are - // updated to use `runtime::Int` and `runtime::Bool`. - TVMRetValue ret; - ret = obj; - try { - ValueType legalized = ret; - return legalized; - } catch (Error& err) { - LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key - << ", but received error when converting to this type.\n" - << err.what(); + auto opt_val = value.as(); + if (!opt_val.has_value()) { + TVM_FFI_THROW(AttributeError) + << "Expect config " << key << " to have type " << type_key << ", but instead get " + << ffi::details::AnyUnsafe::GetMismatchTypeInfo(value); } + return value; } }; @@ -315,7 +301,7 @@ class PassContext : public ObjectRef { TVM_DLL void ExitWithScope(); // Register configuration key value type. TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index, - std::function legalization); + std::function legalization); // Classes to get the Python `with` like syntax. friend class Internal; @@ -551,9 +537,9 @@ class Sequential : public Pass { * * \return The created module pass. */ -TVM_DLL Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, int opt_level, - String name, Array required, bool traceable = false); +TVM_DLL Pass CreateModulePass(std::function pass_func, + int opt_level, String name, Array required, + bool traceable = false); /* * \brief Utility to apply a pass to specific functions in an IRModule diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index e0dbc7be50cf..b4da978b56c2 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -154,12 +154,12 @@ class ScheduleRule : public runtime::ObjectRef { * ignored by default. This function should return True for a block that should be tiled. * \return The schedule rule created */ - TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // - Optional> tile_binds, // - Optional max_innermost_factor, // - Optional> vector_load_lens, // - Optional> reuse_read, // - Optional> reuse_write, + TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // + Optional> tile_binds, // + Optional max_innermost_factor, // + Optional> vector_load_lens, // + Optional> reuse_read, // + Optional> reuse_write, Optional filter_fn = NullOpt); /*! @@ -182,7 +182,7 @@ class ScheduleRule : public runtime::ObjectRef { TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin( String intrin_name, String structure, Optional> tile_binds, Optional max_innermost_factor, Optional> vector_load_lens, - Optional> reuse_read, Optional> reuse_write); + Optional> reuse_read, Optional> reuse_write); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate @@ -207,8 +207,8 @@ class ScheduleRule : public runtime::ObjectRef { TVM_DLL static ScheduleRule MultiLevelTilingTensorCore( Array> intrin_groups, String structure, Optional> tile_binds, Optional max_innermost_factor, - Optional> vector_load_lens, Optional> reuse_read, - Optional> reuse_write, bool use_software_pipeline); + Optional> vector_load_lens, Optional> reuse_read, + Optional> reuse_write, bool use_software_pipeline); /*! * \brief Extension of MultiLevelTiling for backends with wide vectors. @@ -223,7 +223,7 @@ class ScheduleRule : public runtime::ObjectRef { */ TVM_DLL static ScheduleRule MultiLevelTilingWideVector( String structure, Integer vector_length_in_bits, Optional max_innermost_factor, - Optional> reuse_read, Optional> reuse_write); + Optional> reuse_read, Optional> reuse_write); /*! * \brief Create a rule: add-rfactor to some blocks if needed @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 3f44a2438d22..aeef1bff306c 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -41,38 +41,38 @@ class SearchStrategy; /*! * \brief The search strategy for measure candidates generation. * \note The relationship between SearchStrategy and other classes are as follows: - ┌──────────────────────────────────────────────────────────────┐ - ┌──┴───────────────────────────────────────────────────────────┐ │ -┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ -│ ┌─────────────────────┐ │ │ │ -│ │ │ Generate │ │ │ -│ │ Space Generator ├──────────────┐ │ │ │ -│ │ │ │ │ │ │ -│ └─────────────────────┘ ▼ │ │ │ -│ Design Space │ │ │ -│ ┌─────────────────────┐ │ │ │ │ -│ Generate │ │ Pretuning │ │ │ │ -│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ -│ │ │ │ │ ├──┘ -│ │ └─────────────────────┘ ├──┘ -└────┼─────────────────────────────────────────────────────────┘ - │ - │ -┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ -│ │ ┌───────────┐ │ -│ │ Send to │ │ Send to │ -│ ▼ ┌─────────────►│ Builder ├──────────┐ │ -│ Measure Candidate │ Builder │ │ Runner │ │ -│ │ │ └───────────┘ │ │ -│ │ ┌────────────┴────────┐ │ │ -│ │ │ │ ┌───────────┐ │ │ -│ └────►│ Task Scheduler │ │ │ │ │ -│ │ │ │ Runner │◄─────────┘ │ -│ └─────────────────────┘ │ │ │ -│ ▲ └─────┬─────┘ │ -│ │ │ │ -│ └─── Runner Future ◄────┘ │ -└─────────────────────────────────────────────────────────────────────┘ + +--------------------------------------------------------------+ + +--+-----------------------------------------------------------+ | + +--+------------------ Tune Context -----------------------------+ | | + | +---------------------+ | | | + | | | Generate | | | + | | Space Generator +--------------+ | | | + | | | | | | | + | +---------------------+ v | | | + | Design Space | | | + | +---------------------+ | | | | + | Generate | | Pretuning | | | | + | +-----------+ Search Strategy |<-------------+ | | | + | | | | | +--+ + | | +---------------------+ +--+ + +----+----------------------------------------------------------+ + | + | + +----+---------------- Managed By Task Scheduler ---------------------+ + | | +-----------+ | + | | Send to | | Send to | + | v +-------------+| Builder +----------+ | + | Measure Candidate | Builder | | Runner | | + | | | +-----------+ | | + | | +------------+------------+ | | + | | | | +-----------+ | | + | +---->| Task Scheduler | | | | | + | | | | Runner |<-----+ | + | +-------------------------+ | | | + | ^ +-----+-----+ | + | | | | + | +---- Runner Future <-------+ | + +---------------------------------------------------------------------+ */ class SearchStrategyNode : public runtime::Object { public: diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f746eb809194..650320d1e21c 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -40,38 +40,38 @@ class SpaceGenerator; /*! * \brief The abstract class for design space generation. * \note The relationship between SpaceGenerator and other classes are as follows: - ┌──────────────────────────────────────────────────────────────┐ - ┌──┴───────────────────────────────────────────────────────────┐ │ -┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ -│ ┌─────────────────────┐ │ │ │ -│ │ │ Generate │ │ │ -│ │ Space Generator ├──────────────┐ │ │ │ -│ │ │ │ │ │ │ -│ └─────────────────────┘ ▼ │ │ │ -│ Design Space │ │ │ -│ ┌─────────────────────┐ │ │ │ │ -│ Generate │ │ Pretuning │ │ │ │ -│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ -│ │ │ │ │ ├──┘ -│ │ └─────────────────────┘ ├──┘ -└────┼─────────────────────────────────────────────────────────┘ - │ - │ -┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ -│ │ ┌───────────┐ │ -│ │ Send to │ │ Send to │ -│ ▼ ┌─────────────►│ Builder ├──────────┐ │ -│ Measure Candidate │ Builder │ │ Runner │ │ -│ │ │ └───────────┘ │ │ -│ │ ┌────────────┴────────┐ │ │ -│ │ │ │ ┌───────────┐ │ │ -│ └────►│ Task Scheduler │ │ │ │ │ -│ │ │ │ Runner │◄─────────┘ │ -│ └─────────────────────┘ │ │ │ -│ ▲ └─────┬─────┘ │ -│ │ │ │ -│ └─── Runner Future ◄────┘ │ -└─────────────────────────────────────────────────────────────────────┘ + +--------------------------------------------------------------+ + +--+-----------------------------------------------------------+ | + +--+------------------ Tune Context -----------------------------+ | | + | +---------------------+ | | | + | | | Generate | | | + | | Space Generator +--------------+ | | | + | | | | | | | + | +---------------------+ v | | | + | Design Space | | | + | +---------------------+ | | | | + | Generate | | Pretuning | | | | + | +-----------+ Search Strategy |<-------------+ | | | + | | | | | +--+ + | | +---------------------+ +--+ + +----+----------------------------------------------------------+ + | + | + +----+---------------- Managed By Task Scheduler ---------------------+ + | | +-----------+ | + | | Send to | | Send to | + | v +-------------+| Builder +----------+ | + | Measure Candidate | Builder | | Runner | | + | | | +-----------+ | | + | | +------------+------------+ | | + | | | | +-----------+ | | + | +---->| Task Scheduler | | | | | + | | | | Runner |<-----+ | + | +-------------------------+ | | | + | ^ +-----+-----+ | + | | | | + | +---- Runner Future <-------+ | + +---------------------------------------------------------------------+ */ class SpaceGeneratorNode : public runtime::Object { public: diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index f4fc491286dd..8cc3595d68d8 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -92,38 +92,38 @@ class TaskRecord : public runtime::ObjectRef { /*! * \brief The abstract interface of task schedulers. * \note The relationship between SpaceGenerator and other classes are as follows: - ┌──────────────────────────────────────────────────────────────┐ - ┌──┴───────────────────────────────────────────────────────────┐ │ -┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ -│ ┌─────────────────────┐ │ │ │ -│ │ │ Generate │ │ │ -│ │ Space Generator ├──────────────┐ │ │ │ -│ │ │ │ │ │ │ -│ └─────────────────────┘ ▼ │ │ │ -│ Design Space │ │ │ -│ ┌─────────────────────┐ │ │ │ │ -│ Generate │ │ Pretuning │ │ │ │ -│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ -│ │ │ │ │ ├──┘ -│ │ └─────────────────────┘ ├──┘ -└────┼─────────────────────────────────────────────────────────┘ - │ - │ -┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ -│ │ ┌───────────┐ │ -│ │ Send to │ │ Send to │ -│ ▼ ┌─────────────►│ Builder ├──────────┐ │ -│ Measure Candidate │ Builder │ │ Runner │ │ -│ │ │ └───────────┘ │ │ -│ │ ┌────────────┴────────┐ │ │ -│ │ │ │ ┌───────────┐ │ │ -│ └────►│ Task Scheduler │ │ │ │ │ -│ │ │ │ Runner │◄─────────┘ │ -│ └─────────────────────┘ │ │ │ -│ ▲ └─────┬─────┘ │ -│ │ │ │ -│ └─── Runner Future ◄────┘ │ -└─────────────────────────────────────────────────────────────────────┘ + +--------------------------------------------------------------+ + +--+-----------------------------------------------------------+ | + +--+------------------ Tune Context -----------------------------+ | | + | +---------------------+ | | | + | | | Generate | | | + | | Space Generator +--------------+ | | | + | | | | | | | + | +---------------------+ v | | | + | Design Space | | | + | +---------------------+ | | | | + | Generate | | Pretuning | | | | + | +-----------+ Search Strategy |<-------------+ | | | + | | | | | +--+ + | | +---------------------+ +--+ + +----+----------------------------------------------------------+ + | + | + +----+---------------- Managed By Task Scheduler ---------------------+ + | | +-----------+ | + | | Send to | | Send to | + | v +-------------+| Builder +----------+ | + | Measure Candidate | Builder | | Runner | | + | | | +-----------+ | | + | | +------------+------------+ | | + | | | | +-----------+ | | + | +---->| Task Scheduler | | | | | + | | | | Runner |<-----+ | + | +-------------------------+ | | | + | ^ +-----+-----+ | + | | | | + | +---- Runner Future <-------+ | + +---------------------------------------------------------------------+ */ class TaskSchedulerNode : public runtime::Object { public: diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index c4b54ef0f27d..909d0ad0d4f9 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -74,7 +74,11 @@ class AttrRegistryMapContainerMap { ICHECK(key.defined()); const uint32_t idx = key->AttrRegistryIndex(); if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; + if constexpr (std::is_same_v) { + return data_[idx].first; + } else { + return data_[idx].first.template cast(); + } } else { return def_value; } @@ -116,7 +120,13 @@ class AttrRegistryMap { * \param key The key to the map * \return the const reference to the content value. */ - ValueType operator[](const KeyType& key) const { return map_[key]; } + ValueType operator[](const KeyType& key) const { + if constexpr (std::is_same_v) { + return map_[key]; + } else { + return map_[key].template cast(); + } + } /*! * \brief get the corresponding value element at key with default value. * \param key The key to the map diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h index 97a62bfd2d8f..b2bfa4f27379 100644 --- a/include/tvm/node/object_path.h +++ b/include/tvm/node/object_path.h @@ -93,7 +93,7 @@ class ObjectPathNode : public Object { ObjectPath MissingArrayElement(int32_t index) const; /*! \brief Extend this path with access to a map value. */ - ObjectPath MapValue(ObjectRef key) const; + ObjectPath MapValue(ffi::Any key) const; /*! \brief Extend this path with access to a missing map entry. */ ObjectPath MissingMapEntry() const; @@ -245,9 +245,9 @@ class MissingArrayElementPath : public ObjectPath { class MapValuePathNode : public ObjectPathNode { public: /*! \brief Key of the map entry that is being accessed */ - ObjectRef key; + ffi::Any key; - explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key); + explicit MapValuePathNode(const ObjectPathNode* parent, ffi::Any key); static constexpr const char* _type_key = "MapValuePath"; TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 20e482ae72e7..0ee09e70f474 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -62,7 +63,10 @@ class AttrVisitor { TVM_DLL virtual void Visit(const char* key, void** value) = 0; TVM_DLL virtual void Visit(const char* key, DataType* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0; - TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; + TVM_DLL virtual void Visit(const char* key, ffi::ObjectRef* value) = 0; + TVM_DLL virtual void Visit(const char* key, Optional* value) = 0; + TVM_DLL virtual void Visit(const char* key, Optional* value) = 0; + template ::value>::type> void Visit(const char* key, ENum* ptr) { static_assert(std::is_same::type>::value, @@ -160,7 +164,7 @@ class ReflectionVTable { * \param kwargs The field arguments. * \return The created object. */ - TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map& kwargs); + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map& kwargs); /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 3293b43564cc..30bfe8e95193 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -43,6 +43,8 @@ class ReprPrinter { /*! \brief The node to be printed. */ TVM_DLL void Print(const ObjectRef& node); + /*! \brief The node to be printed. */ + TVM_DLL void Print(const ffi::Any& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. @@ -91,7 +93,7 @@ TVM_DLL void Dump(const runtime::Object* node); } // namespace tvm namespace tvm { -namespace runtime { +namespace ffi { // default print function for all objects // provide in the runtime namespace as this is where objectref originally comes from. inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) @@ -99,12 +101,24 @@ inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLI return os; } +// default print function for any +inline std::ostream& operator<<(std::ostream& os, const Any& n) { // NOLINT(*) + ReprPrinter(os).Print(n); + return os; +} + +template +inline std::ostream& operator<<(std::ostream& os, const Variant& n) { // NOLINT(*) + ReprPrinter(os).Print(Any(n)); + return os; +} + inline std::string AsLegacyRepr(const ObjectRef& n) { std::ostringstream os; ReprLegacyPrinter(os).Print(n); return os.str(); } -} // namespace runtime -using runtime::AsLegacyRepr; +} // namespace ffi +using ffi::AsLegacyRepr; } // namespace tvm #endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index 8e0cafb494fb..9d2fa1023e92 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -26,6 +26,9 @@ #include #include #include +#include +#include +#include #include #include @@ -148,7 +151,7 @@ class PrinterConfigNode : public Object { class PrinterConfig : public ObjectRef { public: - explicit PrinterConfig(Map config_dict = Map()); + explicit PrinterConfig(Map config_dict = Map()); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrinterConfig, runtime::ObjectRef, PrinterConfigNode); diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h index ac675946e0eb..c99d0f7f73fb 100644 --- a/include/tvm/node/serialization.h +++ b/include/tvm/node/serialization.h @@ -36,7 +36,7 @@ namespace tvm { * * \return the string representation of the node. */ -TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node); +TVM_DLL std::string SaveJSON(ffi::Any node); /*! * \brief Internal implementation of LoadJSON @@ -45,7 +45,7 @@ TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node); * * \return The shared_ptr of the Node. */ -TVM_DLL runtime::ObjectRef LoadJSON(std::string json_str); +TVM_DLL ffi::Any LoadJSON(std::string json_str); } // namespace tvm #endif // TVM_NODE_SERIALIZATION_H_ diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 2dd732c72a4d..7f56fd6ca961 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -58,6 +58,12 @@ class BaseValueEqual { bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } + bool operator()(const Optional& lhs, const Optional& rhs) const { + return lhs == rhs; + } + bool operator()(const Optional& lhs, const Optional& rhs) const { + return lhs == rhs; + } bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; } bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } @@ -122,6 +128,24 @@ class StructuralEqual : public BaseValueEqual { */ TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const bool map_free_params = false) const; + + /*! + * \brief Compare any value via strutural equal. + * \param lhs The left operand. + * \param rhs The right operand. + * \param map_free_params Whether or not to map free variables. + * \return The comparison result. + */ + TVM_FFI_INLINE bool operator()(const ffi::Any& lhs, const ffi::Any& rhs, + bool map_free_params = false) const { + if (lhs.type_index() != rhs.type_index()) return false; + if (lhs.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + return operator()(lhs.cast(), rhs.cast(), map_free_params); + } + // POD value can always use v_int64 to get the hash value + return (ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs)->v_uint64 == + ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs)->v_uint64); + } }; /*! @@ -185,6 +209,19 @@ class SEqualReducer { */ virtual void MarkGraphNode() = 0; + /*! + * \brief Map lhs to rhs. + * \param lhs The left operand. + * \return The corresponding rhs value if any, nullptr if not available. + */ + TVM_FFI_INLINE ffi::Any MapLhsToRhs(const ffi::Any& lhs) { + if (lhs.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + return MapLhsToRhs(lhs.cast()); + } else { + return lhs; + } + } + protected: using PathTracingData = SEqualReducer::PathTracingData; }; @@ -227,7 +264,10 @@ class SEqualReducer { Optional paths = NullOpt) const; bool operator()(const DataType& lhs, const DataType& rhs, Optional paths = NullOpt) const; - + bool operator()(const Optional& lhs, const Optional& rhs, + Optional paths = NullOpt) const; + bool operator()(const Optional& lhs, const Optional& rhs, + Optional paths = NullOpt) const; template ::value>::type> bool operator()(const ENum& lhs, const ENum& rhs, Optional paths = NullOpt) const { @@ -257,7 +297,7 @@ class SEqualReducer { * \param rhs The right operand. * \return the immediate check result. */ - bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; + bool operator()(const ffi::ObjectRef& lhs, const ffi::ObjectRef& rhs) const; /*! * \brief Reduce condition to comparison of two objects. @@ -279,6 +319,16 @@ class SEqualReducer { return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths); } + /* + * \brief Compare two Any values. + * \param lhs The left operand. + * \param rhs The right operand. + * \param paths Object paths for `lhs` and `rhs`. + * \return the immediate check result. + */ + bool AnyEqual(const ffi::Any& lhs, const ffi::Any& rhs, + Optional paths = NullOpt) const; + /*! * \brief Reduce condition to comparison of two definitions, * where free vars can be mapped. @@ -398,7 +448,7 @@ class SEqualHandlerDefault : public SEqualReducer::Handler { * \param map_free_vars Whether or not to remap variables if possible. * \return The equality result. */ - virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars); + virtual bool Equal(const ffi::Any& lhs, const ffi::Any& rhs, bool map_free_vars); protected: /*! diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 553f284b8c5a..ca547b17149a 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -76,7 +76,29 @@ class BaseValueHash { return Reinterpret(static_cast(key)); } uint64_t operator()(const std::string& key) const { - return runtime::String::StableHashBytes(key.data(), key.length()); + return tvm::ffi::details::StableHashBytes(key.data(), key.length()); + } + uint64_t operator()(const Optional& key) const { + if (key.has_value()) { + return Reinterpret(*key); + } else { + return 0; + } + } + uint64_t operator()(const Optional& key) const { + if (key.has_value()) { + return Reinterpret(*key); + } else { + return 0; + } + } + /*! + * \brief Compute structural hash value for a POD value in Any. + * \param key The Any object. + * \return The hash value. + */ + TVM_FFI_INLINE uint64_t HashPODValueInAny(const ffi::Any& key) const { + return ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(key)->v_uint64; } }; @@ -100,7 +122,19 @@ class StructuralHash : public BaseValueHash { * \param key The left operand. * \return The hash value. */ - TVM_DLL uint64_t operator()(const ObjectRef& key) const; + TVM_DLL uint64_t operator()(const ffi::ObjectRef& key) const; + + /** + * \brief Compute structural hashing value for an Any object. + * \param key The Any object. + * \return The hash value. + */ + TVM_FFI_INLINE uint64_t operator()(const ffi::Any& key) const { + if (key.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + return operator()(key.cast()); + } + return HashPODValueInAny(key); + } }; /*! @@ -185,6 +219,17 @@ class SHashReducer { // handle normal values. handler_->SHashReduceHashedValue(BaseValueHash()(key)); } + /** + * \brief Push hash of Any object to the current sequence of hash values. + * \param key The Any object. + */ + void operator()(const ffi::Any& key) const { + if (key.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + return operator()(key.cast()); + } + // POD value can always use v_int64 to get the hash value + handler_->SHashReduceHashedValue(BaseValueHash().HashPODValueInAny(key)); + } /*! * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. @@ -240,7 +285,7 @@ class SHashHandlerDefault : public SHashReducer::Handler { * \param map_free_vars Whether or not to remap variables if possible. * \return The hash result. */ - virtual uint64_t Hash(const ObjectRef& object, bool map_free_vars); + virtual uint64_t Hash(const ffi::Any& object, bool map_free_vars); protected: /*! diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index b658758e3c8f..48cc35fcb886 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -541,8 +541,8 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ -TVM_DLL Optional FindImpureCall(const Expr& expr, - const Optional& own_name = Optional(nullptr)); +TVM_DLL Optional FindImpureCall( + const Expr& expr, const Optional& own_name = Optional(std::nullopt)); /*! * \brief Check if the given expression (likely a function body) contains any impure calls. @@ -556,7 +556,7 @@ TVM_DLL Optional FindImpureCall(const Expr& expr, * an impure call--it only does if the nested function is *later called*. */ TVM_DLL bool ContainsImpureCall(const Expr& expr, - const Optional& own_name = Optional(nullptr)); + const Optional& own_name = Optional(std::nullopt)); /*! * \brief Check if the IRModule is well formed. diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index aa6c2e146104..d0fdfa129863 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes used in take operator */ struct TakeAttrs : public tvm::AttrsNode { - Optional axis; + Optional axis; TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") { TVM_ATTR_FIELD(axis).describe("The axis over which to select values."); diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 943d2f4d0d71..6a8e41116a11 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -32,7 +32,7 @@ namespace relax { /*! \brief Attributes used in concat operators */ struct ConcatAttrs : public tvm::AttrsNode { - Optional axis; + Optional axis; TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs") { TVM_ATTR_FIELD(axis).describe( @@ -135,7 +135,7 @@ struct StackAttrs : public tvm::AttrsNode { /*! \brief Attributes used in repeat operators */ struct RepeatAttrs : public tvm::AttrsNode { int repeats; - Optional axis; + Optional axis; TVM_DECLARE_ATTRS(RepeatAttrs, "relax.attrs.RepeatAttrs") { TVM_ATTR_FIELD(repeats).describe("The number of repetitions."); diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index f0f80ad8f4a0..911a7d449e3b 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -586,7 +586,7 @@ struct AttentionAttrs : public tvm::AttrsNode { /*! \brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { Array pad_width; - runtime::Float pad_value = 0.0; + double pad_value = 0.0; tvm::String pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relax.attrs.PadAttrs") { diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 10c267b21faa..e97d82986e26 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -33,7 +33,7 @@ namespace relax { /*! \brief Attributes used in call_tir_with_grad */ struct CallTIRWithGradAttrs : public tvm::AttrsNode { String te_grad_name; - Map te_grad_kwargs; + Map te_grad_kwargs; TVM_DECLARE_ATTRS(CallTIRWithGradAttrs, "relax.attrs.CallTIRWithGradAttrs") { TVM_ATTR_FIELD(te_grad_name) diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index f3854078f11e..b6adb3e437c6 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes for search operators */ struct ArgmaxArgminAttrs : public tvm::AttrsNode { - Optional axis; + Optional axis; bool keepdims; TVM_DECLARE_ATTRS(ArgmaxArgminAttrs, "relax.attrs.ArgmaxArgminAttrs") { diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index 9f9a2fa87064..42f2eff0d126 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -44,7 +44,7 @@ struct StatisticalAttrs : public tvm::AttrsNode { /*! \brief Attributes used in scan operators like cumsum, cumprod */ struct ScanopAttrs : public tvm::AttrsNode { - Optional axis; + Optional axis; DataType dtype; Bool exclusive = Bool(false); diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 070aef2fcb6d..5efe91a5e437 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -257,7 +257,6 @@ class BlockBuilderNode : public Object { */ virtual arith::Analyzer* GetAnalyzer() = 0; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.BlockBuilder"; TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); }; diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index b3bbebd0e06c..36fac906c4de 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -113,7 +113,7 @@ class DFPattern : public ObjectRef { /*! \brief Syntatic Sugar for creating a NotPattern */ TVM_DLL NotPattern operator~() const; /*! \brief Syntatic Sugar for creating an AttrPattern */ - TVM_DLL AttrPattern HasAttr(const Map& attrs) const; + TVM_DLL AttrPattern HasAttr(const Map& attrs) const; /*! \brief Syntatic Sugar for creating a StructInfoPattern */ TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const; /*! \brief Syntatic Sugar for creating a TypePattern */ diff --git a/include/tvm/relax/distributed/transform.h b/include/tvm/relax/distributed/transform.h index 31727b181ec9..99b23331f70a 100644 --- a/include/tvm/relax/distributed/transform.h +++ b/include/tvm/relax/distributed/transform.h @@ -41,6 +41,8 @@ using PassInfo = tvm::transform::PassInfo; using PassContext = tvm::transform::PassContext; using Function = tvm::relax::Function; using DataflowBlock = tvm::relax::DataflowBlock; +using tvm::transform::CreateModulePass; + /*! * \brief Propagate sharding information. * diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index 8940408f8048..c5103719d028 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -139,7 +139,6 @@ class ExecBuilderNode : public Object { void VisitAttrs(AttrVisitor* v) {} - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.ExecBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 4904b02960a6..96b5b20d1ef8 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -71,7 +71,7 @@ class ExprFunctor; #define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \ { \ if (PY_FUNC != nullptr) { \ - RET_TYPE ret = PY_FUNC(N); \ + RET_TYPE ret = PY_FUNC(N).cast(); \ return ret; \ } else { \ return DEFAULT_FUNC; \ @@ -89,7 +89,7 @@ class ExprFunctor; #define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \ vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ if (self->PY_FUNC != nullptr) { \ - Expr expr = self->PY_FUNC(n); \ + Expr expr = self->PY_FUNC(n).cast(); \ return expr; \ } else { \ return self->VisitExpr_(static_cast(n.get())); \ diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 1ad5f02e0763..0ddcb271ab83 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -130,7 +130,7 @@ class NestedMsg : public ObjectRef { */ explicit NestedMsg(ObjectPtr ptr) : ObjectRef(ptr) {} /*! \brief Nullopt handling */ - NestedMsg(runtime::NullOptType) {} // NOLINT(*) + NestedMsg(std::nullopt_t) {} // NOLINT(*) // nullptr handling. // disallow implicit conversion as 0 can be implicitly converted to nullptr_t explicit NestedMsg(std::nullptr_t) {} @@ -176,7 +176,7 @@ class NestedMsg : public ObjectRef { bool IsNull() const { return data_ == nullptr; } /*! \return Whether the nested message is nested */ - bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } + bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } /*! * \return The underlying leaf value. diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 02634dcbbf71..1b29cb03582d 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -66,7 +66,7 @@ class MatchResult : public ObjectRef { TVM_DLL explicit MatchResult(TIRPattern pattern, Array symbol_values, Array matched_buffers); - TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode) + TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode); }; using FCodegen = runtime::TypedPackedFunc(Array match_results)>; diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index eaad44a93ace..2da2ba53c701 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -38,6 +38,7 @@ using PassInfo = tvm::transform::PassInfo; using PassContext = tvm::transform::PassContext; using Function = tvm::relax::Function; using DataflowBlock = tvm::relax::DataflowBlock; +using tvm::transform::CreateModulePass; /*! * \brief Create a function pass. @@ -50,9 +51,9 @@ using DataflowBlock = tvm::relax::DataflowBlock; * * \return The created function pass. */ -TVM_DLL Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required, bool traceable = false); +TVM_DLL Pass CreateFunctionPass(std::function pass_func, + int opt_level, String name, tvm::Array required, + bool traceable = false); /*! * \brief Create a dataflowblock pass. @@ -66,8 +67,8 @@ TVM_DLL Pass CreateFunctionPass( * \return The created dataflowblock pass. */ TVM_DLL Pass CreateDataflowBlockPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required, bool traceable = false); + std::function pass_func, int opt_level, + String name, tvm::Array required, bool traceable = false); /*! * \brief Perform lambda lifting to lift functions from nested into global. @@ -388,7 +389,7 @@ class FusionPatternNode : public Object { * \brief The function to get attributes for fused function * * It should have signature - * Map(const Map& context) + * Map(const Map& context) */ Optional attrs_getter; @@ -546,7 +547,7 @@ TVM_DLL Pass FuseTIR(); * \param entry_functions list of entry functions * \return The Pass. */ -TVM_DLL Pass RunCodegen(Optional>> target_options, +TVM_DLL Pass RunCodegen(Optional>> target_options, Array entry_functions); /*! @@ -585,8 +586,8 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional func_name); */ TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, const Map>& op_buffer_transforms, - const Map>>& axis_separators, - const Map>>& input_axis_separators); + const Map>>>& axis_separators, + const Map>>>& input_axis_separators); /*! * \brief Layout conversion pass. diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h index b6224a6d6d9e..bcbfad2c7ac9 100644 --- a/include/tvm/relax/tuning_api.h +++ b/include/tvm/relax/tuning_api.h @@ -27,25 +27,23 @@ #include #include +#include #include + namespace tvm { namespace relax { /*! \brief Helper function to unpack arguments in the array as parameters for the given packed * function. */ TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, - const Array& args) { + const Array& args) { size_t num_args = args.size(); - std::vector values(num_args); - std::vector codes(num_args); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - const ObjectRef* ptr = args.template as()->begin(); + std::vector packed_args(num_args); for (size_t i = 0; i < num_args; ++i) { - setter(i, *(ptr + i)); + packed_args[i] = args[i]; } - - TVMRetValue rv; - f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv); + Any rv; + f.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv); return rv; } @@ -56,8 +54,8 @@ class ChoiceNode : public runtime::Object { String transform_func_key; /*! \brief ffi key for constraint function. */ String constr_func_key; - Array transform_func_args; - Array constr_func_args; + Array transform_func_args; + Array constr_func_args; /*! \brief The default destructor. */ virtual ~ChoiceNode() = default; @@ -71,33 +69,33 @@ class ChoiceNode : public runtime::Object { /*! \brief Getter for constr_func. */ const runtime::PackedFunc GetConstrFunc() { - const auto* constr_func = tvm::runtime::Registry::Get(constr_func_key); - ICHECK(constr_func != nullptr) << "constr_func_key is not registered: " << constr_func_key; - return *constr_func; + const auto constr_func = tvm::ffi::Function::GetGlobal(constr_func_key); + ICHECK(constr_func.has_value()) << "constr_func_key is not registered: " << constr_func_key; + return *std::move(constr_func); } /*! \brief Getter for transform_func. */ const runtime::PackedFunc GetTransformFunc() { - auto* transform_func = tvm::runtime::Registry::Get(transform_func_key); - ICHECK(transform_func != nullptr) + auto transform_func = tvm::ffi::Function::GetGlobal(transform_func_key); + ICHECK(transform_func.has_value()) << "transform_func_key is not registered: " << transform_func_key; - return *transform_func; + return *std::move(transform_func); } /*! \brief Perform constr_func. */ - bool CheckConstr(const IRModule& mod) { - Array args(constr_func_args); - args.insert(args.begin(), mod); - return CallPackedWithArgsInArray(GetConstrFunc(), args); + bool CheckConstr(IRModule mod) { + Array args(constr_func_args); + args.insert(args.begin(), GetRef(mod.CopyOnWrite())); + return CallPackedWithArgsInArray(GetConstrFunc(), args).cast(); } /*! \brief Perform transform_func. */ IRModule ApplyTransformFunc(IRModule mod) { // Apply transformation when constraint is satisfied. if (CheckConstr(mod)) { - Array args(transform_func_args); + Array args(transform_func_args); args.insert(args.begin(), GetRef(mod.CopyOnWrite())); - return CallPackedWithArgsInArray(GetTransformFunc(), args); + return CallPackedWithArgsInArray(GetTransformFunc(), args).cast(); } return mod; } @@ -115,8 +113,8 @@ class ChoiceNode : public runtime::Object { /*! \brief Managed reference to ChoiceNode */ class Choice : public runtime::ObjectRef { public: - TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, - String constr_func_key, Array constr_func_args); + TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args); /*! \brief Deserialize JSON-style object into Choice */ TVM_DLL static Choice FromJSON(const ObjectRef& json_obj); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode); diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index ba8fdfac5565..7d8de1c23423 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -18,902 +18,24 @@ */ /*! - * \file tvm/runtime/container/array.h - * \brief Runtime Array container types. + * \file tvm/runtime/container/string.h + * \brief Runtime String container types. */ #ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_ #define TVM_RUNTIME_CONTAINER_ARRAY_H_ -#include -#include -#include -#include -#include - -#include "./base.h" -#include "./optional.h" +#include namespace tvm { namespace runtime { -/*! \brief array node content in array */ -class ArrayNode : public Object, public InplaceArrayBase { - public: - /*! \return The size of the array */ - size_t size() const { return this->size_; } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const ObjectRef at(int64_t i) const { return this->operator[](i); } - - /*! \return begin constant iterator */ - const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } - - /*! \return end constant iterator */ - const ObjectRef* end() const { return begin() + size_; } - - /*! \brief Release reference to all the elements */ - void clear() { ShrinkBy(size_); } - - /*! - * \brief Set i-th element of the array in-place - * \param i The index - * \param item The value to be set - */ - void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } - - /*! - * \brief Constructs a container and copy from another - * \param cap The capacity of the container - * \param from Source of the copy - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { - int64_t size = from->size_; - ICHECK_GE(cap, size) << "ValueError: not enough capacity"; - ObjectPtr p = ArrayNode::Empty(cap); - ObjectRef* write = p->MutableBegin(); - ObjectRef* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) ObjectRef(*read++); - } - return p; - } - - /*! - * \brief Constructs a container and move from another - * \param cap The capacity of the container - * \param from Source of the move - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { - int64_t size = from->size_; - ICHECK_GE(cap, size) << "ValueError: not enough capacity"; - ObjectPtr p = ArrayNode::Empty(cap); - ObjectRef* write = p->MutableBegin(); - ObjectRef* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) ObjectRef(std::move(*read++)); - } - from->size_ = 0; - return p; - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { - ObjectPtr p = ArrayNode::Empty(n); - ObjectRef* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < n; ++i) { - new (itr++) ObjectRef(val); - } - return p; - } - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; - static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); - - private: - /*! \return Size of initialized memory, used by InplaceArrayBase. */ - size_t GetSize() const { return this->size_; } - - /*! \return begin mutable iterator */ - ObjectRef* MutableBegin() const { - return static_cast(InplaceArrayBase::AddressOf(0)); - } - - /*! \return end mutable iterator */ - ObjectRef* MutableEnd() const { return MutableBegin() + size_; } - - /*! - * \brief Create an ArrayNode with the given capacity. - * \param n Required capacity - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr Empty(int64_t n = kInitSize) { - ICHECK_GE(n, 0); - ObjectPtr p = make_inplace_array_object(n); - p->capacity_ = n; - p->size_ = 0; - return p; - } - - /*! - * \brief Inplace-initialize the elements starting idx from [first, last) - * \param idx The starting point - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return Self - */ - template - ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { - ObjectRef* itr = MutableBegin() + idx; - for (; first != last; ++first) { - ObjectRef ref = *first; - new (itr++) ObjectRef(std::move(ref)); - } - return this; - } - - /*! - * \brief Move elements from right to left, requires src_begin > dst - * \param dst Destination - * \param src_begin The start point of copy (inclusive) - * \param src_end The end point of copy (exclusive) - * \return Self - */ - ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { - ObjectRef* from = MutableBegin() + src_begin; - ObjectRef* to = MutableBegin() + dst; - while (src_begin++ != src_end) { - *to++ = std::move(*from++); - } - return this; - } - - /*! - * \brief Move elements from left to right, requires src_begin < dst - * \param dst Destination - * \param src_begin The start point of move (inclusive) - * \param src_end The end point of move (exclusive) - * \return Self - */ - ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { - ObjectRef* from = MutableBegin() + src_end; - ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); - while (src_begin++ != src_end) { - *--to = std::move(*--from); - } - return this; - } - - /*! - * \brief Enlarges the size of the array - * \param delta Size enlarged, should be positive - * \param val Default value - * \return Self - */ - ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { - ObjectRef* itr = MutableEnd(); - while (delta-- > 0) { - new (itr++) ObjectRef(val); - ++size_; - } - return this; - } - - /*! - * \brief Shrinks the size of the array - * \param delta Size shrinked, should be positive - * \return Self - */ - ArrayNode* ShrinkBy(int64_t delta) { - ObjectRef* itr = MutableEnd(); - while (delta-- > 0) { - (--itr)->ObjectRef::~ObjectRef(); - --size_; - } - return this; - } - - /*! \brief Number of elements used */ - int64_t size_; - - /*! \brief Number of elements allocated */ - int64_t capacity_; - - /*! \brief Initial size of ArrayNode */ - static constexpr int64_t kInitSize = 4; - - /*! \brief Expansion factor of the Array */ - static constexpr int64_t kIncFactor = 2; - - // CRTP parent class - friend InplaceArrayBase; - - // Reference class - template - friend class Array; - - // To specialize make_object - friend ObjectPtr make_object<>(); -}; - -/*! \brief Helper struct for type-checking - * - * is_valid_iterator::value will be true if IterType can - * be dereferenced into a type that can be stored in an Array, and - * false otherwise. - */ -template -struct is_valid_iterator - : std::bool_constant())>>>> {}; - -template -struct is_valid_iterator, IterType> : is_valid_iterator {}; - -template -inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; - -/*! - * \brief Array, container representing a contiguous sequence of ObjectRefs. - * - * Array implements in-place copy-on-write semantics. - * - * As in typical copy-on-write, a method which would typically mutate the array - * instead opaquely copies the underlying container, and then acts on its copy. - * - * If the array has reference count equal to one, we directly update the - * container in place without copying. This is optimization is sound because - * when the reference count is equal to one this reference is guranteed to be - * the sole pointer to the container. - * - * - * operator[] only provides const access, use Set to mutate the content. - * \tparam T The content ObjectRef type. - */ -template ::value>::type> -class Array : public ObjectRef { - public: - using value_type = T; - // constructors - /*! - * \brief default constructor - */ - Array() { data_ = ArrayNode::Empty(); } - - /*! - * \brief move constructor - * \param other source - */ - Array(Array&& other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - - /*! - * \brief copy constructor - * \param other source - */ - Array(const Array& other) : ObjectRef() { // NOLINT(*) - data_ = other.data_; - } - - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief Constructor from iterator - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - Assign(first, last); - } - - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Array(std::initializer_list init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } - - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - - public: - // iterators - struct ValueConverter { - using ResultType = T; - static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } - }; - - using iterator = IterAdapter; - using reverse_iterator = ReverseIterAdapter; - - /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayNode()->begin()); } - - /*! \return end iterator */ - iterator end() const { return iterator(GetArrayNode()->end()); } - - /*! \return rbegin iterator */ - reverse_iterator rbegin() const { - // ArrayNode::end() is never nullptr - return reverse_iterator(GetArrayNode()->end() - 1); - } - - /*! \return rend iterator */ - reverse_iterator rend() const { - // ArrayNode::begin() is never nullptr - return reverse_iterator(GetArrayNode()->begin() - 1); - } - - public: - // const methods in std::vector - /*! - * \brief Immutably read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const T operator[](int64_t i) const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK(0 <= i && i < p->size_) - << "IndexError: indexing " << i << " on an array of size " << p->size_; - return DowncastNoCheck(*(p->begin() + i)); - } - - /*! \return The size of the array */ - size_t size() const { - ArrayNode* p = GetArrayNode(); - return p == nullptr ? 0 : GetArrayNode()->size_; - } - - /*! \return The capacity of the array */ - size_t capacity() const { - ArrayNode* p = GetArrayNode(); - return p == nullptr ? 0 : GetArrayNode()->capacity_; - } - - /*! \return Whether array is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the array */ - const T front() const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; - return DowncastNoCheck(*(p->begin())); - } - - /*! \return The last element of the array */ - const T back() const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; - return DowncastNoCheck(*(p->end() - 1)); - } - - public: - // mutation in std::vector, implements copy-on-write - - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - void push_back(const T& item) { - ArrayNode* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, item); - } - - /*! - * \brief Insert an element into the given position - * \param position An iterator pointing to the insertion point - * \param val The element to insert - */ - void insert(iterator position, const T& val) { - ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - auto addr = CopyOnWrite(1) // - ->EnlargeBy(1) // - ->MoveElementsRight(idx + 1, idx, size) // - ->MutableBegin(); - new (addr + idx) ObjectRef(val); - } - - /*! - * \brief Insert a range of elements into the given position - * \param position An iterator pointing to the insertion point - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - template - void insert(iterator position, IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - - if (first == last) { - return; - } - ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - int64_t numel = std::distance(first, last); - CopyOnWrite(numel) - ->EnlargeBy(numel) - ->MoveElementsRight(idx + numel, idx, size) - ->InitRange(idx, first, last); - } - - /*! \brief Remove the last item of the list */ - void pop_back() { - ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; - int64_t size = GetArrayNode()->size_; - ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; - CopyOnWrite()->ShrinkBy(1); - } - - /*! - * \brief Erase an element on the given position - * \param position An iterator pointing to the element to be erased - */ - void erase(iterator position) { - ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; - int64_t st = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st - << ", because Array size is " << size; - CopyOnWrite() // - ->MoveElementsLeft(st, st + 1, size) // - ->ShrinkBy(1); - } - - /*! - * \brief Erase a given range of elements - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - void erase(iterator first, iterator last) { - if (first == last) { - return; - } - ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; - int64_t size = GetArrayNode()->size_; - int64_t st = std::distance(begin(), first); - int64_t ed = std::distance(begin(), last); - ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; - ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) - << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" - << ", because array size is " << size; - CopyOnWrite() // - ->MoveElementsLeft(st, ed, size) // - ->ShrinkBy(ed - st); - } - - /*! - * \brief Resize the array. - * \param n The new size. - */ - void resize(int64_t n) { - ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; - if (data_ == nullptr) { - SwitchContainer(n); - return; - } - int64_t size = GetArrayNode()->size_; - if (size < n) { - CopyOnWrite(n - size)->EnlargeBy(n - size); - } else if (size > n) { - CopyOnWrite()->ShrinkBy(size - n); - } - } - - /*! - * \brief Make sure the list has the capacity of at least n - * \param n lower bound of the capacity - */ - void reserve(int64_t n) { - if (data_ == nullptr || n > GetArrayNode()->capacity_) { - SwitchContainer(n); - } - } - - /*! \brief Release reference to all the elements */ - void clear() { - if (data_ != nullptr) { - ArrayNode* p = CopyOnWrite(); - p->clear(); - } - } - - template - static size_t CalcCapacityImpl() { - return 0; - } - - template - static size_t CalcCapacityImpl(Array value, Args... args) { - return value.size() + CalcCapacityImpl(args...); - } - - template - static size_t CalcCapacityImpl(T value, Args... args) { - return 1 + CalcCapacityImpl(args...); - } - - template - static void AgregateImpl(Array& dest) {} // NOLINT(*) - - template - static void AgregateImpl(Array& dest, Array value, Args... args) { // NOLINT(*) - dest.insert(dest.end(), value.begin(), value.end()); - AgregateImpl(dest, args...); - } - - template - static void AgregateImpl(Array& dest, T value, Args... args) { // NOLINT(*) - dest.push_back(value); - AgregateImpl(dest, args...); - } - - public: - // Array's own methods - - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - void Set(int64_t i, T value) { - ArrayNode* p = this->CopyOnWrite(); - ICHECK(0 <= i && i < p->size_) - << "IndexError: indexing " << i << " on an array of size " << p->size_; - *(p->MutableBegin() + i) = std::move(value); - } - - /*! \return The underlying ArrayNode */ - ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } - - /*! - * \brief Helper function to apply a map function onto the array. - * - * \param fmap The transformation function T -> U. - * - * \tparam F The type of the mutation function. - * - * \tparam U The type of the returned array, inferred from the - * return type of F. If overridden by the user, must be something - * that is convertible from the return type of F. - * - * \note This function performs copy on write optimization. If - * `fmap` returns an object of type `T`, and all elements of the - * array are mapped to themselves, then the returned array will be - * the same as the original, and reference counts of the elements in - * the array will not be incremented. - * - * \return The transformed array. - */ - template > - Array Map(F fmap) const { - return Array(MapHelper(data_, fmap)); - } - - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. - * \note This function performs copy on write optimization. - */ - template >>> - void MutateByApply(F fmutate) { - data_ = MapHelper(std::move(data_), fmutate); - } - - /*! - * \brief reset the array to content from iterator. - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - void Assign(IterType first, IterType last) { - int64_t cap = std::distance(first, last); - ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; - ArrayNode* p = GetArrayNode(); - if (p != nullptr && data_.unique() && p->capacity_ >= cap) { - // do not have to make new space - p->clear(); - } else { - // create new space - data_ = ArrayNode::Empty(cap); - p = GetArrayNode(); - } - // To ensure exception safety, size is only incremented after the initialization succeeds - ObjectRef* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { - new (itr) ObjectRef(*first); - } - } - - /*! - * \brief Copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - ArrayNode* CopyOnWrite() { - if (data_ == nullptr) { - return SwitchContainer(ArrayNode::kInitSize); - } - if (!data_.unique()) { - return SwitchContainer(capacity()); - } - return static_cast(data_.get()); - } - - /*! \brief specify container node */ - using ContainerType = ArrayNode; - - /*! - * \brief Agregate arguments into a single Array - * \param args sequence of T or Array elements - * \return Agregated Array - */ - template - static Array Agregate(Args... args) { - Array result; - result.reserve(CalcCapacityImpl(args...)); - AgregateImpl(result, args...); - return result; - } - - private: - /*! - * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. - * \param reserve_extra Number of extra slots needed - * \return ArrayNode pointer to the unique copy - */ - ArrayNode* CopyOnWrite(int64_t reserve_extra) { - ArrayNode* p = GetArrayNode(); - if (p == nullptr) { - // necessary to get around the constexpr address issue before c++17 - const int64_t kInitSize = ArrayNode::kInitSize; - return SwitchContainer(std::max(kInitSize, reserve_extra)); - } - if (p->capacity_ >= p->size_ + reserve_extra) { - return CopyOnWrite(); - } - int64_t cap = p->capacity_ * ArrayNode::kIncFactor; - cap = std::max(cap, p->size_ + reserve_extra); - return SwitchContainer(cap); - } - - /*! - * \brief Move or copy the ArrayNode to new address with the given capacity - * \param capacity The capacity requirement of the new address - */ - ArrayNode* SwitchContainer(int64_t capacity) { - if (data_ == nullptr) { - data_ = ArrayNode::Empty(capacity); - } else if (data_.unique()) { - data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); - } else { - data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); - } - return static_cast(data_.get()); - } - - /*! \brief Helper method for mutate/map - * - * A helper function used internally by both `Array::Map` and - * `Array::MutateInPlace`. Given an array of data, apply the - * mapping function to each element, returning the collected array. - * Applies both mutate-in-place and copy-on-write optimizations, if - * possible. - * - * \param data A pointer to the ArrayNode containing input data. - * Passed by value to allow for mutate-in-place optimizations. - * - * \param fmap The mapping function - * - * \tparam F The type of the mutation function. - * - * \tparam U The output type of the mutation function. Inferred - * from the callable type given. Must inherit from ObjectRef. - * - * \return The mapped array. Depending on whether mutate-in-place - * or copy-on-write optimizations were applicable, may be the same - * underlying array as the `data` parameter. - */ - template > - static ObjectPtr MapHelper(ObjectPtr data, F fmap) { - if (data == nullptr) { - return nullptr; - } - - ICHECK(data->IsInstance()); - - constexpr bool is_same_output_type = std::is_same_v; - - if constexpr (is_same_output_type) { - if (data.unique()) { - // Mutate-in-place path. Only allowed if the output type U is - // the same as type T, we have a mutable this*, and there are - // no other shared copies of the array. - auto arr = static_cast(data.get()); - for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { - T mapped = fmap(DowncastNoCheck(std::move(*it))); - *it = std::move(mapped); - } - return data; - } - } - - constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; - - ObjectPtr output = nullptr; - auto arr = static_cast(data.get()); - - auto it = arr->begin(); - if constexpr (compatible_types) { - // Copy-on-write path, if the output Array might be - // represented by the same underlying array as the existing - // Array. Typically, this is for functions that map `T` to - // `T`, but can also apply to functions that map `T` to - // `Optional`, or that map `T` to a subclass or superclass of - // `T`. - bool all_identical = true; - for (; it != arr->end(); it++) { - U mapped = fmap(DowncastNoCheck(*it)); - if (!mapped.same_as(*it)) { - // At least one mapped element is different than the - // original. Therefore, prepare the output array, - // consisting of any previous elements that had mapped to - // themselves (if any), and the element that didn't map to - // itself. - // - // We cannot use `U()` as the default object, as `U` may be - // a non-nullable type. Since the default `ObjectRef()` - // will be overwritten before returning, all objects will be - // of type `U` for the calling scope. - all_identical = false; - output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); - output->InitRange(0, arr->begin(), it); - output->SetItem(it - arr->begin(), std::move(mapped)); - it++; - break; - } - } - if (all_identical) { - return data; - } - } else { - // Path for incompatible types. The constexpr check for - // compatible types isn't strictly necessary, as the first - // mapped.same_as(*it) would return false, but we might as well - // avoid it altogether. - // - // We cannot use `U()` as the default object, as `U` may be a - // non-nullable type. Since the default `ObjectRef()` will be - // overwritten before returning, all objects will be of type `U` - // for the calling scope. - output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); - } - - // Normal path for incompatible types, or post-copy path for - // copy-on-write instances. - // - // If the types are incompatible, then at this point `output` is - // empty, and `it` points to the first element of the input. - // - // If the types were compatible, then at this point `output` - // contains zero or more elements that mapped to themselves - // followed by the first element that does not map to itself, and - // `it` points to the element just after the first element that - // does not map to itself. Because at least one element has been - // changed, we no longer have the opportunity to avoid a copy, so - // we don't need to check the result. - // - // In both cases, `it` points to the next element to be processed, - // so we can either start or resume the iteration from that point, - // with no further checks on the result. - for (; it != arr->end(); it++) { - U mapped = fmap(DowncastNoCheck(*it)); - output->SetItem(it - arr->begin(), std::move(mapped)); - } - - return output; - } -}; - -template -inline constexpr bool is_tvm_array = false; - -template -inline constexpr bool is_tvm_array> = true; - -/*! - * \brief Concat two Arrays. - * \param lhs first Array to be concatenated. - * \param rhs second Array to be concatenated. - * \return The concatenated Array. Original Arrays are kept unchanged. - */ -template ::value>::type> -inline Array Concat(Array lhs, const Array& rhs) { - for (const auto& x : rhs) { - lhs.push_back(x); - } - return std::move(lhs); -} - -// Specialize make_object to make sure it is correct. -template <> -inline ObjectPtr make_object() { - return ArrayNode::Empty(); -} +using tvm::ffi::Array; +using tvm::ffi::ArrayObj; } // namespace runtime -// expose the functions to the root namespace. -using runtime::Array; -using runtime::ArrayNode; +// expose class to root namespace +using tvm::ffi::Array; +using tvm::ffi::ArrayObj; } // namespace tvm - #endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ diff --git a/include/tvm/runtime/container/base.h b/include/tvm/runtime/container/base.h index 51d48ae7f23b..b0295761f6a3 100644 --- a/include/tvm/runtime/container/base.h +++ b/include/tvm/runtime/container/base.h @@ -35,27 +35,6 @@ namespace tvm { namespace runtime { -/*! \brief String-aware ObjectRef equal functor */ -struct ObjectHash { - /*! - * \brief Calculate the hash code of an ObjectRef - * \param a The given ObjectRef - * \return Hash code of a, string hash for strings and pointer address otherwise. - */ - size_t operator()(const ObjectRef& a) const; -}; - -/*! \brief String-aware ObjectRef hash functor */ -struct ObjectEqual { - /*! - * \brief Check if the two ObjectRef are equal - * \param a One ObjectRef - * \param b The other ObjectRef - * \return String equality if both are strings, pointer address equality otherwise. - */ - bool operator()(const ObjectRef& a, const ObjectRef& b) const; -}; - /*! * \brief Base template for classes with array like memory layout. * @@ -290,8 +269,6 @@ using runtime::Downcast; using runtime::IterAdapter; using runtime::make_object; using runtime::Object; -using runtime::ObjectEqual; -using runtime::ObjectHash; using runtime::ObjectPtr; using runtime::ObjectPtrEqual; using runtime::ObjectPtrHash; diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h deleted file mode 100644 index 8d01b5dc17b5..000000000000 --- a/include/tvm/runtime/container/boxed_primitive.h +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/boxed_primitive.h - * \brief Runtime container types for primitives stored as ObjectRef. - */ -#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ -#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ - -#include -#include - -namespace tvm { -namespace runtime { - -namespace detail { -/* \brief Provide the BoxNode type traits in templated contexts - * - * The Box class is used in many templated contexts, and is easier - * to have templated over the primitive type. - * - * However, much of the TVM type system depends on classes having a - * unique name. For example, the use of `Object::IsInstance` depends - * on `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will - * result in duplicate indices, and invalid downcasting. Furthermore, - * the name must be specified in the Python FFI using - * `tvm._ffi.register_object`. This prevents use of - * `typeid(T)::name()` to build a unique name, as the name is not - * required to be human-readable or consistent across compilers. - * - * This utility struct should be specialized over the primitive type - * held by the box, to allow explicit listing of the `_type_key` and - * other similar tratis. - * - * Note: This should only contain traits that are required at runtime, - * and should *not* contain extensions for features that are only - * available at compile-time. For integration with compile-time-only - * functionality (e.g. StructuralHash, StructuralEqual), see - * `BoxNodeCompileTimeTraits` in `src/node/boxed_primitive.cc`. - */ -template -struct BoxNodeRuntimeTraits; - -} // namespace detail - -template -class BoxNode : public Object { - public: - /*! \brief Constructor - * - * \param value The value to be boxed - */ - explicit BoxNode(Prim value) : value(value) {} - - /*! \brief The boxed value */ - Prim value; - - static constexpr const char* _type_key = detail::BoxNodeRuntimeTraits::_type_key; - static constexpr bool _type_has_method_visit_attrs = false; - TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); -}; - -template -class Box : public ObjectRef { - public: - /*! \brief Constructor - * - * \param value The value to be boxed - */ - Box(Prim value) : ObjectRef(make_object>(value)) {} // NOLINT(*) - - operator Prim() const { return (*this)->value; } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); -}; - -/*! \brief Boxed version of C++ int64_t - * - * Can be used to store POD integer values as a TVM ObjectRef. Used - * for FFI handling, and for storing POD types inside TVM containers. - */ -using Int = Box; - -/*! \brief Boxed version of C++ double - * - * Can be used to store POD floating-point values as a TVM ObjectRef. - * Used for FFI handling, and for storing POD types inside TVM - * containers. - */ -using Float = Box; - -/*! \brief Boxed version of C++ bool - * - * Can be used to store POD boolean values as a TVM ObjectRef. Used - * for FFI handling, and for storing POD types inside TVM containers. - * - * When passing from Python to C++, TVM PackedFunc conversion follow - * C++ conversion rules, and allow bool->int and int->bool - * conversions. When passing from C++ to Python, the types are - * returned as bool or int. If the C++ function uses ObjectRef to - * hold the object, a Python to C++ to Python round trip will preserve - * the distinction between bool and int. - */ -using Bool = Box; - -namespace detail { -template <> -struct BoxNodeRuntimeTraits { - static constexpr const char* _type_key = "runtime.BoxInt"; -}; - -template <> -struct BoxNodeRuntimeTraits { - static constexpr const char* _type_key = "runtime.BoxFloat"; -}; - -template <> -struct BoxNodeRuntimeTraits { - static constexpr const char* _type_key = "runtime.BoxBool"; -}; -} // namespace detail - -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index eb86ddb7b8f9..cd63cc94ada0 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -24,1462 +24,17 @@ #ifndef TVM_RUNTIME_CONTAINER_MAP_H_ #define TVM_RUNTIME_CONTAINER_MAP_H_ -#ifndef USE_FALLBACK_STL_MAP -#define USE_FALLBACK_STL_MAP 0 -#endif - -#include -#include -#include - -#include "./base.h" -#include "./optional.h" +#include namespace tvm { namespace runtime { -#if TVM_DEBUG_WITH_ABI_CHANGE -#define TVM_MAP_FAIL_IF_CHANGED() \ - ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; -#else -#define TVM_MAP_FAIL_IF_CHANGED() -#endif // TVM_DEBUG_WITH_ABI_CHANGE - -#if (USE_FALLBACK_STL_MAP != 0) - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of the actual underlying container */ - using ContainerType = std::unordered_map; - /*! \brief Iterator class */ - using iterator = ContainerType::iterator; - /*! \brief Iterator class */ - using const_iterator = ContainerType::const_iterator; - /*! \brief Type of value stored in the hash map */ - using KVType = ContainerType::value_type; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return data_.size(); } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return data_.count(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return data_.at(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return data_.at(key); } - /*! \return begin iterator */ - iterator begin() { return data_.begin(); } - /*! \return const begin iterator */ - const_iterator begin() const { return data_.begin(); } - /*! \return end iterator */ - iterator end() { return data_.end(); } - /*! \return end iterator */ - const_iterator end() const { return data_.end(); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - const_iterator find(const key_type& key) const { return data_.find(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) { return data_.find(key); } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { data_.erase(position); } - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { data_.erase(key); } - /*! - * \brief Create an empty container - * \return The object created - */ - static ObjectPtr Empty() { return make_object(); } - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static ObjectPtr CreateFromRange(IterType first, IterType last) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(first, last); - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - MapNode* map_node = static_cast(map->get()); - map_node->data_[kv.first] = kv.second; - } - /*! - * \brief Create an empty container with elements copying from another MapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(MapNode* from) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(from->data_.begin(), from->data_.end()); - return p; - } - /*! \brief The real container storing data */ - ContainerType data_; - template - friend class Map; -}; - -#else - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; -/*! \brief Default constructor */ -#if TVM_DEBUG_WITH_ABI_CHANGE - iterator() : state_marker(0), index(0), self(nullptr) {} -#else - iterator() : index(0), self(nullptr) {} -#endif // TVM_DEBUG_WITH_ABI_CHANGE - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - TVM_MAP_FAIL_IF_CHANGED() - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { - TVM_MAP_FAIL_IF_CHANGED() - return *((*this).operator->()); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - TVM_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - TVM_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - --(*this); - return copy; - } - - protected: -#if TVM_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; - /*! \brief Construct by value */ - iterator(uint64_t index, const MapNode* self) - : state_marker(self->state_marker), index(index), self(self) {} - -#else - iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} -#endif // TVM_DEBUG_WITH_ABI_CHANGE - /*! \brief The position on the array */ - uint64_t index; - /*! \brief The container it points to */ - const MapNode* self; - - friend class DenseMapNode; - friend class SmallMapNode; - }; - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: -#if TVM_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; -#endif // TVM_DEBUG_WITH_ABI_CHANGE - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapNode* from); - /*! \brief number of slots minus 1 */ - uint64_t slots_; - /*! \brief number of entries in the container */ - uint64_t size_; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapNode : public MapNode, - public runtime::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapNode::iterator; - using MapNode::KVType; - - /*! \brief Defaults to the destructor of InplaceArrayBase */ - ~SmallMapNode() = default; - /*! - * \brief Count the number of times a key exists in the SmallMapNode - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(AddressOf(0)); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (ObjectEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Remove a position in SmallMapNode - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType* begin = static_cast(AddressOf(0)); - KVType* last = begin + (size_ - 1); - if (index + 1 == size_) { - last->first.ObjectRef::~ObjectRef(); - last->second.ObjectRef::~ObjectRef(); - } else { - *(begin + index) = std::move(*last); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::runtime::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->size_ = 0; - p->slots_ = n; - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->AddressOf(0)); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapNode* from) { - KVType* first = static_cast(from->AddressOf(0)); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - SmallMapNode* map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->slots_) { - KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); - new (ptr) KVType(kv); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); - next_size = std::min(next_size, uint64_t(kMaxSize)); - ICHECK_GT(next_size, map_node->slots_); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapNode; - friend class DenseMapNode; - friend class runtime::InplaceArrayBase; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. Overview - * - * DenseMapNode did several improvements over traditional separate chaining hash, - * in terms of cache locality, memory footprints and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using linked list - * explicitly for each bucket, we store list data into a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and - * traversal. This can be divided in 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * head of a linked list. - * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are - * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to - * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. - * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid - * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i - * indicates that it is list head, then we found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this - * element is in the linked list, and if not, we put it at the end by probing the next empty - * position in one of the 126 candidate positions. If the linked list does not even exist, but the - * slot for list head has been occupied by another linked list, we should find this intruder another - * place. - * - * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing - * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list - * head. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapNode : public MapNode { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); - /*! \brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout::value, "Block is not standard layout"); - - public: - using MapNode::iterator; - - /*! - * \brief Destroy the DenseMapNode - */ - ~DenseMapNode() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->slots_) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { - if (slots_ == 0) { - return iterator(0, this); - } - for (uint64_t index = 0; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return iterator(index, this); - } - } - return iterator(slots_ + 1, this); - } - /*! \return end iterator */ - iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } - - private: - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (ObjectEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; - return iter.Val(); - } - /*! - * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(ObjectHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (ObjectEqual()(key, next.Key())) { - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(KVType(key, ObjectRef(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. - * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - empty.NewTail(std::move(r.Data())); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done moving the linked list - // fill data_ into `target` - target.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - iter.Data().KVType::~KVType(); - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - iter.Data() = std::move(last.Data()); - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! \brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->slots_); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - data_ptr->KVType::~KVType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - delete[] data_; - data_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots - 1); - Block* block = p->data_ = new Block[n_blocks]; - p->slots_ = n_slots - 1; - p->size_ = 0; - p->fib_shift_ = fib_shift; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another DenseMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapNode* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->slots_); - p->data_ = new Block[n_blocks]; - p->slots_ = from->slots_; - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->data_[bi].bytes; - KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); - uint8_t* meta_ptr_to = p->data_[bi].bytes; - KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - ICHECK(meta != kProtectedSlot); - if (meta != uint8_t(kEmptySlot)) { - new (data_ptr_to) KVType(*data_ptr_from); - } - } - } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - DenseMapNode* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = kv.second; - return; - } - ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); - // Insert the given `kv` into the new hash map - InsertMaybeReHash(kv, &p); - uint64_t n_blocks = CalcNumBlocks(map_node->slots_); - // Then Insert data from the original block. - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = map_node->data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - KVType kv = std::move(*data_ptr); - InsertMaybeReHash(kv, &p); - } - } - } - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - for (++index; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - while (index != 0) { - index -= 1; - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { - uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; - return (n_slots + kBlockCap - 1) / kBlockCap; - } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. - * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapNode* self) - : index(index), block(self->data_ + (index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - KVType& Data() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(KVType))); - } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(KVType v) const { - Meta() = 0b00000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(KVType v) const { - Meta() = 0b10000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self, uint8_t meta) { - uint64_t offset = NextProbeLocation(meta & 0b01111111); - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - index = (index + offset) & (self->slots_); - block = self->data_ + (index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapNode* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(ObjectHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! \brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - ListNode candidate((index + NextProbeLocation(idx)) & (self->slots_), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief array of data blocks */ - Block* data_; - static uint64_t NextProbeLocation(size_t index) { - /* clang-format off */ - /*! \brief Candidates of probing distance */ - static const uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, - 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - return kNextProbeLocation[index]; - } - friend class MapNode; -}; - -#define TVM_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapNode*; \ - using TDense = DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -#define TVM_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapNode*; \ - using TDense = const DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -inline MapNode::iterator::pointer MapNode::iterator::operator->() const { - TVM_MAP_FAIL_IF_CHANGED() - TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapNode::iterator& MapNode::iterator::operator++() { - TVM_MAP_FAIL_IF_CHANGED() - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapNode::iterator& MapNode::iterator::operator--() { - TVM_MAP_FAIL_IF_CHANGED() - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapNode::count(const key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { - TVM_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapNode::iterator MapNode::begin() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapNode::iterator MapNode::end() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapNode::erase(const MapNode::iterator& position) { - TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); -} - -#undef TVM_DISPATCH_MAP -#undef TVM_DISPATCH_MAP_CONST - -inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } - -inline ObjectPtr MapNode::CopyFrom(MapNode* from) { - if (from->slots_ <= SmallMapNode::kMaxSize) { - return SmallMapNode::CopyFrom(static_cast(from)); - } else { - return DenseMapNode::CopyFrom(static_cast(from)); - } -} - -template -inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapNode::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapNode::kMaxSize) { - return SmallMapNode::CreateFromRange(cap, first, last); - } - uint32_t fib_shift; - uint64_t n_slots; - DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapNode::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapNode::InsertMaybeReHash(kv, &obj); - } - return obj; -} - -inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; - MapNode* base = static_cast(map->get()); -#if TVM_DEBUG_WITH_ABI_CHANGE - base->state_marker++; -#endif // TVM_DEBUG_WITH_ABI_CHANGE - if (base->slots_ < kSmallMapMaxSize) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else if (base->slots_ == kSmallMapMaxSize) { - if (base->size_ < base->slots_) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else { - ObjectPtr new_map = MapNode::CreateFromRange(base->begin(), base->end()); - DenseMapNode::InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - } else { - DenseMapNode::InsertMaybeReHash(kv, map); - } -} - -template <> -inline ObjectPtr make_object<>() = delete; - -#endif - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -class Map : public ObjectRef { - public: - using key_type = K; - using mapped_type = V; - class iterator; - /*! - * \brief default constructor - */ - Map() { data_ = MapNode::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) { data_ = std::move(other.data_); } - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) : ObjectRef(other.data_) {} - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapNode::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : GetMapNode()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapNode* n = GetMapNode(); - if (n != nullptr) { - data_ = MapNode::Empty(); - } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapNode()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapNode()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } - /*! \return The value associated with the key, NullOpt if not found */ - Optional Get(const K& key) const { - MapNode::iterator iter = GetMapNode()->find(key); - if (iter == GetMapNode()->end()) { - return NullOptType{}; - } - return DowncastNoCheck(iter->second); - } - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which guarantees to be unique) - */ - MapNode* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapNode::Empty(); - } else if (!data_.unique()) { - data_ = MapNode::CopyFrom(GetMapNode()); - } - return GetMapNode(); - } - /*! \brief specify container node */ - using ContainerType = MapNode; - - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - private: - iterator(const MapNode::iterator& itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapNode::iterator itr; - }; - - private: - /*! \brief Return data_ as type of pointer of MapNode */ - MapNode* GetMapNode() const { return static_cast(data_.get()); } -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. - */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} +using tvm::ffi::Map; } // namespace runtime // expose the functions to the root namespace. -using runtime::Map; -using runtime::MapNode; +using tvm::ffi::Map; +using tvm::ffi::MapObj; } // namespace tvm - #endif // TVM_RUNTIME_CONTAINER_MAP_H_ diff --git a/include/tvm/runtime/container/optional.h b/include/tvm/runtime/container/optional.h index 024986a8e037..4dc3b680de7a 100644 --- a/include/tvm/runtime/container/optional.h +++ b/include/tvm/runtime/container/optional.h @@ -18,155 +18,22 @@ */ /*! - * \file tvm/runtime/container/optional.h - * \brief Runtime Optional container types. + * \file tvm/runtime/container/string.h + * \brief Runtime String container types. */ #ifndef TVM_RUNTIME_CONTAINER_OPTIONAL_H_ #define TVM_RUNTIME_CONTAINER_OPTIONAL_H_ -#include - -#include "./base.h" +#include namespace tvm { namespace runtime { -/*! \brief Helper to represent nullptr for optional. */ -struct NullOptType {}; - -/*! - * \brief Optional container that to represent to a Nullable variant of T. - * \tparam T The original ObjectRef. - * - * \code - * - * Optional opt0 = nullptr; - * Optional opt1 = String("xyz"); - * ICHECK(opt0 == nullptr); - * ICHECK(opt1 == "xyz"); - * - * \endcode - */ -template -class Optional : public ObjectRef { - public: - using ContainerType = typename T::ContainerType; - static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); - // default constructors. - Optional() = default; - Optional(const Optional&) = default; - Optional(Optional&&) = default; - Optional& operator=(const Optional&) = default; - Optional& operator=(Optional&&) = default; - /*! - * \brief Construct from an ObjectPtr - * whose type already matches the ContainerType. - * \param ptr - */ - explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! \brief Nullopt handling */ - Optional(NullOptType) {} // NOLINT(*) - // nullptr handling. - // disallow implicit conversion as 0 can be implicitly converted to nullptr_t - explicit Optional(std::nullptr_t) {} - Optional& operator=(std::nullptr_t) { - data_ = nullptr; - return *this; - } - // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) {} - Optional& operator=(T other) { - ObjectRef::operator=(std::move(other)); - return *this; - } - // delete the int constructor - // since Optional(0) is ambiguious - // 0 can be implicitly casted to nullptr_t - explicit Optional(int val) = delete; - Optional& operator=(int val) = delete; - /*! - * \return A not-null container value in the optional. - * \note This function performs not-null checking. - */ - T value() const { - ICHECK(data_ != nullptr); - return T(data_); - } - /*! - * \return The internal object pointer with container type of T. - * \note This function do not perform not-null checking. - */ - const ContainerType* get() const { return static_cast(data_.get()); } - /*! - * \return The contained value if the Optional is not null - * otherwise return the default_value. - */ - T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } - - /*! \return Whether the container is not nullptr.*/ - explicit operator bool() const { return *this != nullptr; } - // operator overloadings - bool operator==(std::nullptr_t) const { return data_ == nullptr; } - bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - auto operator==(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(value() == other.value()); - if (same_as(other)) return RetType(true); - if (*this != nullptr && other != nullptr) { - return value() == other.value(); - } else { - // one of them is nullptr. - return RetType(false); - } - } - auto operator!=(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(value() != other.value()); - if (same_as(other)) return RetType(false); - if (*this != nullptr && other != nullptr) { - return value() != other.value(); - } else { - // one of them is nullptr. - return RetType(true); - } - } - auto operator==(const T& other) const { - using RetType = decltype(value() == other); - if (same_as(other)) return RetType(true); - if (*this != nullptr) return value() == other; - return RetType(false); - } - auto operator!=(const T& other) const { return !(*this == other); } - template - auto operator==(const U& other) const { - using RetType = decltype(value() == other); - if (*this == nullptr) return RetType(false); - return value() == other; - } - template - auto operator!=(const U& other) const { - using RetType = decltype(value() != other); - if (*this == nullptr) return RetType(true); - return value() != other; - } - static constexpr bool _type_is_nullable = true; -}; - -template -inline Optional ObjectRef::as() const { - if (auto* ptr = this->as()) { - return GetRef(ptr); - } else { - return NullOptType{}; - } -} - +using tvm::ffi::Optional; } // namespace runtime -// expose the functions to the root namespace. -using runtime::Optional; -constexpr runtime::NullOptType NullOpt{}; +// expose class to root namespace +using tvm::ffi::Optional; +constexpr inline auto NullOpt = std::nullopt; } // namespace tvm - #endif // TVM_RUNTIME_CONTAINER_OPTIONAL_H_ diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h index 1fb6248cc2f1..c7a96b6623a6 100644 --- a/include/tvm/runtime/container/shape_tuple.h +++ b/include/tvm/runtime/container/shape_tuple.h @@ -24,6 +24,8 @@ #ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ #define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ +#include + #include #include #include @@ -33,167 +35,9 @@ namespace tvm { namespace runtime { -/*! \brief An object representing a shape tuple. */ -class ShapeTupleObj : public Object { - public: - /*! \brief The type of shape index element. */ - using index_type = int64_t; - /*! \brief The pointer to shape tuple data. */ - index_type* data; - /*! \brief The size of the shape tuple object. */ - uint64_t size; - - /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ - index_type Product() const; - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple; - static constexpr const char* _type_key = "runtime.ShapeTuple"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object); - - private: - /*! \brief ShapeTuple object which is moved from std::vector container. */ - class FromStd; - - friend class ShapeTuple; -}; - -/*! \brief An object representing shape tuple moved from std::vector. */ -class ShapeTupleObj::FromStd : public ShapeTupleObj { - public: - /*! \brief The type of shape index element. */ - using index_type = ShapeTupleObj::index_type; - /*! - * \brief Construct a new FromStd object - * - * \param other The moved/copied std::vector object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - explicit FromStd(std::vector other) : data_container{other} {} - - private: - /*! \brief Container that holds the memory. */ - std::vector data_container; - - friend class ShapeTuple; -}; - -/*! - * \brief Reference to shape tuple objects. - */ -class ShapeTuple : public ObjectRef { - public: - /*! \brief The type of shape index element. */ - using index_type = ShapeTupleObj::index_type; - - /*! - * \brief Construct an empty shape tuple. - */ - ShapeTuple() : ShapeTuple(std::vector()) {} - - /*! - * \brief Constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector(begin, end)) {} - - /*! - * \brief constructor from initializer list - * \param shape The initializer list - */ - ShapeTuple(std::initializer_list shape) : ShapeTuple(shape.begin(), shape.end()) {} - - /*! - * \brief Construct a new ShapeTuple object - * - * \param shape The moved/copied std::vector object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - ShapeTuple(std::vector shape); // NOLINT(*) - - /*! - * \brief Return the data pointer - * - * \return const index_type* data pointer - */ - const index_type* data() const { return get()->data; } - - /*! - * \brief Return the size of the shape tuple - * - * \return size_t shape tuple size - */ - size_t size() const { return get()->size; } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - index_type operator[](size_t idx) const { - ICHECK(idx < this->size()) << "IndexError: indexing " << idx << " on an array of size " - << this->size(); - return this->data()[idx]; - } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - index_type at(size_t idx) const { return this->operator[](idx); } - - /*! \return Whether shape tuple is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the shape tuple */ - index_type front() const { return this->at(0); } - - /*! \return The last element of the shape tuple */ - index_type back() const { return this->at(this->size() - 1); } - - /*! \return begin iterator */ - const index_type* begin() const { return get()->data; } - - /*! \return end iterator */ - const index_type* end() const { return (get()->data + size()); } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj); -}; - -inline ShapeTuple::ShapeTuple(std::vector shape) { - auto ptr = make_object(std::move(shape)); - ptr->size = ptr->data_container.size(); - ptr->data = ptr->data_container.data(); - data_ = std::move(ptr); -} - -inline ShapeTupleObj::index_type ShapeTupleObj::Product() const { - index_type numel = 1; - for (int i = 0, n = this->size; i < n; ++i) { - numel *= this->data[i]; - } - return numel; -} - -inline std::ostream& operator<<(std::ostream& os, const ShapeTuple& shape) { - os << '['; - for (size_t i = 0; i < shape->size; ++i) { - if (i != 0) { - os << ", "; - } - os << shape->data[i]; - } - os << ']'; - return os; -} - +using Shape = tvm::ffi::Shape; +using ShapeTuple = tvm::ffi::Shape; +using ShapeTupleObj = tvm::ffi::ShapeObj; using IntTuple = ShapeTuple; using IntTupleObj = ShapeTupleObj; diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h index a7be84de23f9..d55d9dbd7960 100644 --- a/include/tvm/runtime/container/string.h +++ b/include/tvm/runtime/container/string.h @@ -24,521 +24,18 @@ #ifndef TVM_RUNTIME_CONTAINER_STRING_H_ #define TVM_RUNTIME_CONTAINER_STRING_H_ -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace tvm { namespace runtime { -// Forward declare TVMArgValue -class TVMArgValue; - -/*! \brief An object representing string. It's POD type. */ -class StringObj : public Object { - public: - /*! \brief The pointer to string data. */ - const char* data; - - /*! \brief The length of the string object. */ - uint64_t size; - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; - static constexpr const char* _type_key = "runtime.String"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); - - private: - /*! \brief String object which is moved from std::string container. */ - class FromStd; - - friend class String; -}; - -/*! - * \brief Reference to string objects. - * - * \code - * - * // Example to create runtime String reference object from std::string - * std::string s = "hello world"; - * - * // You can create the reference from existing std::string - * String ref{std::move(s)}; - * - * // You can rebind the reference to another string. - * ref = std::string{"hello world2"}; - * - * // You can use the reference as hash map key - * std::unordered_map m; - * m[ref] = 1; - * - * // You can compare the reference object with other string objects - * assert(ref == "hello world", true); - * - * // You can convert the reference to std::string again - * string s2 = (string)ref; - * - * \endcode - */ -class String : public ObjectRef { - public: - /*! - * \brief Construct an empty string. - */ - String() : String(std::string()) {} - /*! - * \brief Construct a new String object - * - * \param other The moved/copied std::string object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - String(std::string other); // NOLINT(*) - - /*! - * \brief Construct a new String object - * - * \param other a char array. - */ - String(const char* other) // NOLINT(*) - : String(std::string(other)) {} - - /*! - * \brief Construct a new null object - */ - String(std::nullptr_t) // NOLINT(*) - : ObjectRef(nullptr) {} - - /*! - * \brief Change the value the reference object points to. - * - * \param other The value for the new String - * - */ - inline String& operator=(std::string other); - - /*! - * \brief Change the value the reference object points to. - * - * \param other The value for the new String - */ - inline String& operator=(const char* other); - - /*! - * \brief Compares this String object to other - * - * \param other The String to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const String& other) const { - return memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this String object to other - * - * \param other The string to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const std::string& other) const { - return memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this to other - * - * \param other The character array to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const char* other) const { - return memncmp(data(), other, size(), std::strlen(other)); - } - - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char* c_str() const { return get()->data; } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { - const auto* ptr = get(); - return ptr->size; - } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t length() const { return size(); } - - /*! - * \brief Retun if the string is empty - * - * \return true if empty, false otherwise. - */ - bool empty() const { return size() == 0; } - - /*! - * \brief Read an element. - * \param pos The position at which to read the character. - * - * \return The char at position - */ - char at(size_t pos) const { - if (pos < size()) { - return data()[pos]; - } else { - throw std::out_of_range("tvm::String index out of bounds"); - } - } - - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const { return get()->data; } - - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{get()->data, size()}; } - - /*! - * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String - * \param val The value to be checked - * \return A boolean indicating if val can be converted to String - */ - inline static bool CanConvertFrom(const TVMArgValue& val); - - /*! - * \brief Hash the binary bytes - * \param data The data pointer - * \param size The size of the bytes. - * \return the hash value. - */ - static uint64_t StableHashBytes(const char* data, size_t size) { - const constexpr uint64_t kMultiplier = 1099511628211ULL; - const constexpr uint64_t kMod = 2147483647ULL; - union Union { - uint8_t a[8]; - uint64_t b; - } u; - static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); - const char* it = data; - const char* end = it + size; - uint64_t result = 0; - for (; it + 8 <= end; it += 8) { - if (DMLC_IO_NO_ENDIAN_SWAP) { - u.a[0] = it[0]; - u.a[1] = it[1]; - u.a[2] = it[2]; - u.a[3] = it[3]; - u.a[4] = it[4]; - u.a[5] = it[5]; - u.a[6] = it[6]; - u.a[7] = it[7]; - } else { - u.a[0] = it[7]; - u.a[1] = it[6]; - u.a[2] = it[5]; - u.a[3] = it[4]; - u.a[4] = it[3]; - u.a[5] = it[2]; - u.a[6] = it[1]; - u.a[7] = it[0]; - } - result = (result * kMultiplier + u.b) % kMod; - } - if (it < end) { - u.b = 0; - uint8_t* a = u.a; - if (it + 4 <= end) { - a[0] = it[0]; - a[1] = it[1]; - a[2] = it[2]; - a[3] = it[3]; - it += 4; - a += 4; - } - if (it + 2 <= end) { - a[0] = it[0]; - a[1] = it[1]; - it += 2; - a += 2; - } - if (it + 1 <= end) { - a[0] = it[0]; - it += 1; - a += 1; - } - if (!DMLC_IO_NO_ENDIAN_SWAP) { - std::swap(u.a[0], u.a[7]); - std::swap(u.a[1], u.a[6]); - std::swap(u.a[2], u.a[5]); - std::swap(u.a[3], u.a[4]); - } - result = (result * kMultiplier + u.b) % kMod; - } - return result; - } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); - - private: - /*! - * \brief Compare two char sequence - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * \return int zero if both char sequences compare equal. negative if this - * appear before other, positive otherwise. - */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); - - /*! - * \brief Concatenate two char sequences - * - * \param lhs Pointers to the lhs char array - * \param lhs_size The size of the lhs char array - * \param rhs Pointers to the rhs char array - * \param rhs_size The size of the rhs char array - * - * \return The concatenated char sequence - */ - static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { - std::string ret(lhs, lhs_size); - ret.append(rhs, rhs_size); - return String(ret); - } - - // Overload + operator - friend String operator+(const String& lhs, const String& rhs); - friend String operator+(const String& lhs, const std::string& rhs); - friend String operator+(const std::string& lhs, const String& rhs); - friend String operator+(const String& lhs, const char* rhs); - friend String operator+(const char* lhs, const String& rhs); - - friend struct tvm::runtime::ObjectEqual; -}; - -/*! \brief An object representing string moved from std::string. */ -class StringObj::FromStd : public StringObj { - public: - /*! - * \brief Construct a new FromStd object - * - * \param other The moved/copied std::string object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - explicit FromStd(std::string other) : data_container{other} {} - - private: - /*! \brief Container that holds the memory. */ - std::string data_container; - - friend class String; -}; - -inline String::String(std::string other) { - auto ptr = make_object(std::move(other)); - ptr->size = ptr->data_container.size(); - ptr->data = ptr->data_container.data(); - data_ = std::move(ptr); -} - -inline String& String::operator=(std::string other) { - String replace{std::move(other)}; - data_.swap(replace.data_); - return *this; -} - -inline String& String::operator=(const char* other) { return operator=(std::string(other)); } +using tvm::ffi::String; +using tvm::ffi::StringObj; -inline String operator+(const String& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const std::string& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const std::string& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const char* lhs, const String& rhs) { - size_t lhs_size = std::strlen(lhs); - size_t rhs_size = rhs.size(); - return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const char* rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = std::strlen(rhs); - return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); -} - -// Overload < operator -inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -// Overload > operator -inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -// Overload <= operator -inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -// Overload >= operator -inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } - -// Overload == operator -inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -// Overload != operator -inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline std::ostream& operator<<(std::ostream& out, const String& input) { - out.write(input.data(), input.size()); - return out; -} - -inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) return 0; - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) return -1; - if (lhs[i] > rhs[i]) return 1; - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } -} - -inline size_t ObjectHash::operator()(const ObjectRef& a) const { - if (const auto* str = a.as()) { - return String::StableHashBytes(str->data, str->size); - } - return ObjectPtrHash()(a); -} - -inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { - if (a.same_as(b)) { - return true; - } - if (const auto* str_a = a.as()) { - if (const auto* str_b = b.as()) { - return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; - } - } - return false; -} } // namespace runtime -// expose the functions to the root namespace. -using runtime::String; -using runtime::StringObj; -} // namespace tvm - -namespace std { - -template <> -struct hash<::tvm::runtime::String> { - std::size_t operator()(const ::tvm::runtime::String& str) const { - return ::tvm::runtime::String::StableHashBytes(str.data(), str.size()); - } -}; -} // namespace std +using tvm::ffi::String; +using tvm::ffi::StringObj; +} // namespace tvm #endif // TVM_RUNTIME_CONTAINER_STRING_H_ diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index e8defa4e6fee..9ba9a987115b 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -19,105 +19,21 @@ /*! * \file tvm/runtime/container/variant.h - * \brief Runtime Variant container types. + * \brief Runtime variant container. */ #ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_ #define TVM_RUNTIME_CONTAINER_VARIANT_H_ -#include - -#include -#include -#include +#include namespace tvm { namespace runtime { -namespace detail { -template -constexpr bool parent_is_base_of_any = false; - -template -constexpr bool parent_is_base_of_any> = - ((std::is_base_of_v && !std::is_same_v) || ...); - -/* \brief Utility to check if any parent is a base class of any child - * - * The type-checking in Variant relies on all types being from - * independent types, such that `Object::IsInstance` is sufficient to - * determine which variant is populated. - * - * For example, suppose the illegal `Variant` - * were allowed (e.g. to represent either the defintion of a variable - * or the usage of a variable). If a function returned - * `tir::PrimExpr`, it could result in either variant being filled, as - * the underlying type at runtime could be a `tir::Var`. This - * behavior is different from `std::variant`, which determines the - * active variant based solely on the compile-time type, and could - * produce very unexpected results if the variants have different - * semantic interpretations. - */ -template -static constexpr bool any_parent_is_base_of_any_child = false; - -template -static constexpr bool any_parent_is_base_of_any_child, ChildTuple> = - (parent_is_base_of_any || ...); -} // namespace detail - -template -class Variant : public ObjectRef { - static constexpr bool all_inherit_from_objectref = (std::is_base_of_v && ...); - static_assert(all_inherit_from_objectref, - "All types used in Variant<...> must inherit from ObjectRef"); - - static constexpr bool a_variant_inherits_from_another_variant = - detail::any_parent_is_base_of_any_child, std::tuple>; - static_assert(!a_variant_inherits_from_another_variant, - "Due to implementation limitations, " - "no type stored in a tvm::runtime::Variant " - "may be a subclass of any other type " - "stored in the same variant."); - - public: - /* \brief Helper utility to check if the type is part of the variant */ - template - static constexpr bool is_variant = (std::is_base_of_v || ...); - - /* \brief Helper utility for SFINAE if the type is part of the variant */ - template - using enable_if_variant = std::enable_if_t>; - - template > - Variant(T value) : ObjectRef(std::move(value)) {} // NOLINT(*) - - template > - Variant& operator=(T value) { - ObjectRef::operator=(std::move(value)); - return *this; - } - - // These functions would normally be declared with the - // TVM_DEFINE_OBJECT_REF_METHODS macro. However, we need additional - // type-checking inside the ObjectPtr constructor. - using ContainerType = Object; - Variant() : ObjectRef() {} - explicit Variant(ObjectPtr node) : ObjectRef(node) { - CHECK(node == nullptr || (node->IsInstance() || ...)) - << "Variant<" - << static_cast( - (std::stringstream() << ... << V::ContainerType::_type_key)) - .str() - << "> cannot hold an object of type " << node->GetTypeKey(); - } - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Variant); -}; +using tvm::ffi::Variant; } // namespace runtime -// expose the functions to the root namespace. -using runtime::Variant; - +// expose class to root namespace +using tvm::ffi::Variant; } // namespace tvm - #endif // TVM_RUNTIME_CONTAINER_VARIANT_H_ diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 06de90eb7cb2..9418a0c902d4 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -24,6 +24,7 @@ #ifndef TVM_RUNTIME_DATA_TYPE_H_ #define TVM_RUNTIME_DATA_TYPE_H_ +#include #include #include @@ -54,7 +55,7 @@ class DataType { kInt = kDLInt, kUInt = kDLUInt, kFloat = kDLFloat, - kHandle = TVMArgTypeCode::kTVMOpaqueHandle, + kHandle = kDLOpaqueHandle, kBFloat = kDLBfloat, kFloat8_e3m4 = kDLFloat8_e3m4, kFloat8_e4m3 = kDLFloat8_e4m3, @@ -347,206 +348,54 @@ inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; } -/*! - * \brief Runtime utility for getting custom type name from code - * \param type_code Custom type code - * \return Custom type name - */ -TVM_DLL std::string GetCustomTypeName(uint8_t type_code); - -/*! - * \brief Runtime utility for checking whether custom type is registered - * \param type_code Custom type code - * \return Bool representing whether type is registered - */ -TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); +using ffi::DLDataTypeToString; +using ffi::StringToDLDataType; -/*! - * \brief Runtime utility for parsing string of the form "custom[]" - * \param s String to parse - * \param scan pointer to parsing pointer, which is scanning across s - * \return type code of custom type parsed - */ -TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) + return os << dtype.operator DLDataType(); +} +} // namespace runtime -/*! - * \brief Convert type code to its name - * \param type_code The type code . - * \return The name of type code. - */ -inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code); +using DataType = runtime::DataType; -/*! - * \brief convert a string to TVM type. - * \param s The string to be converted. - * \return The corresponding tvm type. - */ -inline DLDataType String2DLDataType(std::string s); +namespace ffi { -/*! - * \brief convert a TVM type to string. - * \param t The type to be converted. - * \return The corresponding tvm type in string. - */ -inline std::string DLDataType2String(DLDataType t); +// runtime::DataType +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; -// implementation details -inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { - switch (static_cast(type_code)) { - case kDLInt: - return "int"; - case kDLUInt: - return "uint"; - case kDLFloat: - return "float"; - case DataType::kHandle: - return "handle"; - case kDLBfloat: - return "bfloat"; - case DataType::kFloat8_e4m3fn: - return "float8_e4m3fn"; - case DataType::kFloat8_e5m2: - return "float8_e5m2"; - case DataType::kFloat4_e2m1fn: - return "float4_e2m1fn"; - default: - LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + static TVM_FFI_INLINE void CopyToAnyView(const runtime::DataType& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDataType; + result->v_dtype = src; } - throw; -} -inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) - if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { - os << "bool"; - return os; - } - if (DataType(t).is_void()) { - return os << "void"; + static TVM_FFI_INLINE void MoveToAny(runtime::DataType src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIDataType; + result->v_dtype = src; } - if (t.code < DataType::kCustomBegin) { - os << DLDataTypeCode2Str(static_cast(t.code)); - } else { - os << "custom[" << GetCustomTypeName(t.code) << "]"; - } - if (t.code == kTVMOpaqueHandle) return os; - int16_t lanes = static_cast(t.lanes); - if (t.code != DataType::kFloat8_e4m3fn && t.code != DataType::kFloat8_e5m2 && - t.code != DataType::kFloat4_e2m1fn) { - os << static_cast(t.bits); - } - if (lanes > 1) { - os << 'x' << lanes; - } else if (lanes < -1) { - os << "xvscalex" << -lanes; - } - return os; -} - -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) - return os << dtype.operator DLDataType(); -} - -inline std::string DLDataType2String(DLDataType t) { - if (t.bits == 0) return ""; - std::ostringstream os; - os << t; - return os.str(); -} -inline DLDataType String2DLDataType(std::string s) { - DLDataType t; - // handle void type - if (s.length() == 0 || s == "void") { - t = DataType::Void(); - return t; - } - t.bits = 32; - t.lanes = 1; - const char* scan; - if (s.substr(0, 3) == "int") { - t.code = kDLInt; - scan = s.c_str() + 3; - } else if (s.substr(0, 4) == "uint") { - t.code = kDLUInt; - scan = s.c_str() + 4; - } else if (s.substr(0, 13) == "float4_e2m1fn") { - // Avoid being treated as "float" - t.code = DataType::kFloat4_e2m1fn; - t.bits = 4; - scan = s.c_str() + 13; - char* endpt = nullptr; - if (*scan == 'x') { - t.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); - scan = endpt; - } - ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; - return t; - } else if (s.substr(0, 13) == "float8_e4m3fn") { - // Avoid being treated as "float" - t.code = DataType::kFloat8_e4m3fn; - t.bits = 8; - scan = s.c_str() + 13; - char* endpt = nullptr; - if (*scan == 'x') { - t.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); - scan = endpt; - } - ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; - return t; - } else if (s.substr(0, 11) == "float8_e5m2") { - // Avoid being treated as "float" - t.code = DataType::kFloat8_e5m2; - t.bits = 8; - scan = s.c_str() + 11; - char* endpt = nullptr; - if (*scan == 'x') { - t.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); - scan = endpt; + static TVM_FFI_INLINE std::optional TryConvertFromAnyView( + const TVMFFIAny* src) { + auto opt_dtype = TypeTraits::TryConvertFromAnyView(src); + if (opt_dtype) { + return runtime::DataType(opt_dtype.value()); } - ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; - return t; - } else if (s.substr(0, 5) == "float") { - t.code = kDLFloat; - scan = s.c_str() + 5; - } else if (s.substr(0, 6) == "handle") { - t.code = kTVMOpaqueHandle; - t.bits = 64; // handle uses 64 bit by default. - scan = s.c_str() + 6; - } else if (s == "bool") { - t.code = kDLUInt; - t.bits = 1; - t.lanes = 1; - return t; - } else if (s.substr(0, 6) == "bfloat") { - t.code = DataType::kBFloat; - t.bits = 16; - scan = s.c_str() + 6; - } else if (s.substr(0, 6) == "custom") { - t.code = ParseCustomDatatype(s, &scan); - } else { - scan = s.c_str(); - LOG(FATAL) << "unknown type " << s; + return std::nullopt; } - char* xdelim; // emulate sscanf("%ux%u", bits, lanes) - uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); - if (bits != 0) t.bits = bits; - int scalable_multiplier = 1; - if (strncmp(xdelim, "xvscale", 7) == 0) { - scalable_multiplier = -1; - xdelim += 7; - } - char* endpt = xdelim; - if (*xdelim == 'x') { - t.lanes = static_cast(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10)); + + static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + return TypeTraits::CheckAnyStorage(src); } - ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; - return t; -} -} // namespace runtime + static TVM_FFI_INLINE runtime::DataType CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + return runtime::DataType(TypeTraits::CopyFromAnyStorageAfterCheck(src)); + } -using DataType = runtime::DataType; + static TVM_FFI_INLINE std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } +}; +} // namespace ffi } // namespace tvm namespace std { diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 2564b73d1e94..a4b53eb79734 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -24,13 +24,18 @@ #ifndef TVM_RUNTIME_DEVICE_API_H_ #define TVM_RUNTIME_DEVICE_API_H_ +#include +#include #include -#include -#include +#include #include namespace tvm { + +// alias DLDevice +using Device = DLDevice; + namespace runtime { /*! * \brief the query type into GetAttr @@ -96,7 +101,7 @@ class TVM_DLL DeviceAPI { * \param rv The return value. * \sa DeviceAttrKind */ - virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0; + virtual void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) = 0; /*! * \brief Get the physical memory size required. @@ -104,7 +109,8 @@ class TVM_DLL DeviceAPI { * \param mem_scope the memory scope if any * \return the memory size. */ - virtual size_t GetDataSize(const DLTensor& arr, Optional mem_scope = NullOpt); + virtual size_t GetDataSize(const DLTensor& arr, + ffi::Optional mem_scope = std::nullopt); /*! * \brief Query the device for specified properties. @@ -112,7 +118,7 @@ class TVM_DLL DeviceAPI { * This is used to expand "-from_device=N" in the target string to * all properties that can be determined from that device. */ - virtual void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {} + virtual void GetTargetProperty(Device dev, const std::string& property, ffi::Any* rv) {} /*! * \brief Allocate a data space on device. @@ -135,7 +141,7 @@ class TVM_DLL DeviceAPI { * \return The allocated device pointer. */ virtual void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope = NullOpt); + ffi::Optional mem_scope = std::nullopt); /*! * \brief Free a data space on device. * \param dev The device device to perform operation. @@ -263,6 +269,47 @@ class TVM_DLL DeviceAPI { DLDataType type_hint, TVMStreamHandle stream); }; +/*! + * \brief The name of DLDeviceType. + * \param type The device type. + * \return the device name. + */ +inline const char* DLDeviceType2Str(int type) { + switch (type) { + case kDLCPU: + return "cpu"; + case kDLCUDA: + return "cuda"; + case kDLCUDAHost: + return "cuda_host"; + case kDLCUDAManaged: + return "cuda_managed"; + case kDLOpenCL: + return "opencl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLROCMHost: + return "rocm_host"; + case kDLExtDev: + return "ext_dev"; + case kDLOneAPI: + return "oneapi"; + case kDLWebGPU: + return "webgpu"; + case kDLHexagon: + return "hexagon"; + default: + LOG(FATAL) << "unknown type = " << type; + } + throw; +} + /*! \brief The device type bigger than this is RPC device */ constexpr int kRPCSessMask = 128; static_assert(kRPCSessMask >= TVMDeviceExtType_End); diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 9c34f8a2af9e..32f148853073 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -148,7 +148,8 @@ class DRefObj : public Object { static constexpr const char* _type_key = "runtime.disco.DRef"; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef; - TVM_DECLARE_FINAL_OBJECT_INFO(DRefObj, Object); + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(DRefObj, Object); /*! \brief The id of the register */ int64_t reg_id; @@ -242,7 +243,7 @@ class SessionObj : public Object { * \param value The value to be set. * \param worker_id The id of the worker to be set. */ - TVM_DLL virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0; + TVM_DLL virtual void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) = 0; struct FFI; friend struct SessionObj::FFI; @@ -282,7 +283,7 @@ class Session : public ObjectRef { TVM_DLL static Session ProcessSession(int num_workers, int num_groups, String process_pool_creator, String entrypoint); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); + TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); }; /*! @@ -337,14 +338,13 @@ template DRef SessionObj::CallPacked(const DRef& func, Args&&... args) { constexpr int offset = 3; constexpr int kNumArgs = offset + sizeof...(Args); - TVMValue values[kNumArgs]; - int type_codes[kNumArgs]; - PackArgs(values, type_codes, - /*.0=*/static_cast(DiscoAction::kCallPacked), // action - /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked - /*.2=*/func, // the function to be called - std::forward(args)...); - return this->CallWithPacked(TVMArgs(values, type_codes, kNumArgs)); + AnyView packed_args[kNumArgs]; + ffi::PackedArgs::Fill(packed_args, + /*.0=*/static_cast(DiscoAction::kCallPacked), // action + /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked + /*.2=*/func, // the function to be called + std::forward(args)...); + return this->CallWithPacked(ffi::PackedArgs(packed_args, kNumArgs)); } } // namespace runtime diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index 440c7a4383c3..807c9dbf30bc 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -31,6 +31,7 @@ #include #include +#include #include #include @@ -186,39 +187,8 @@ namespace tvm { namespace runtime { -/*! - * \brief Generate a backtrace when called. - * \return A multiline string of the backtrace. There will be either one or two lines per frame. - */ -TVM_DLL std::string Backtrace(); - -/*! \brief Base error type for TVM. Wraps a string message. */ -class Error : public ::dmlc::Error { // for backwards compatibility - public: - /*! - * \brief Construct an error. - * \param s The message to be displayed with the error. - */ - explicit Error(const std::string& s) : ::dmlc::Error(s) {} -}; - -/*! - * \brief Error message already set in frontend env. - * - * This error can be thrown by EnvCheckSignals to indicate - * that there is an error set in the frontend environment(e.g. - * python interpreter). The TVM FFI should catch this error - * and return a proper code tell the frontend caller about - * this fact. - */ -class EnvErrorAlreadySet : public ::dmlc::Error { - public: - /*! - * \brief Construct an error. - * \param s The message to be displayed with the error. - */ - explicit EnvErrorAlreadySet(const std::string& s) : ::dmlc::Error(s) {} -}; +using ffi::EnvErrorAlreadySet; +using ffi::Error; /*! * \brief Error type for errors from CHECK, ICHECK, and LOG(FATAL). This error @@ -234,47 +204,41 @@ class InternalError : public Error { * \param time The time at which the error occurred. This should be in local time. * \param backtrace Backtrace from when the error occurred. */ - InternalError(std::string file, int lineno, std::string message, - std::time_t time = std::time(nullptr), std::string backtrace = Backtrace()) - : Error(""), - file_(file), - lineno_(lineno), - message_(message), - time_(time), - backtrace_(backtrace) { - std::ostringstream s; - // XXX: Do not change this format, otherwise all error handling in python will break (because it - // parses the message to reconstruct the error type). - // TODO(tkonolige): Convert errors to Objects, so we can avoid the mess of formatting/parsing - // error messages correctly. - s << "[" << std::put_time(std::localtime(&time), "%H:%M:%S") << "] " << file << ":" << lineno - << ": " << message << std::endl; - if (backtrace.size() > 0) { - s << backtrace << std::endl; + InternalError(std::string file, int lineno, std::string message) + : Error(DetectKind(message), DetectMessage(message), + TVMFFITraceback(file.c_str(), lineno, "")) {} + + private: + // try to detect the kind of error from the message when the error type + // is folded into the text message + static std::string DetectKind(const std::string& message) { + size_t pos = message.find("Error:"); + if (pos != std::string::npos) { + size_t end = pos + 6; + size_t begin = pos; + for (; begin != 0 && message[begin - 1] != ' '; --begin) { + } + return message.substr(begin, end - begin - 1); + } else { + return "InternalError"; } - full_message_ = s.str(); } - /*! \return The file in which the error occurred. */ - const std::string& file() const { return file_; } - /*! \return The message associated with this error. */ - const std::string& message() const { return message_; } - /*! \return Formatted error message including file, linenumber, backtrace, and message. */ - const std::string& full_message() const { return full_message_; } - /*! \return The backtrace from where this error occurred. */ - const std::string& backtrace() const { return backtrace_; } - /*! \return The time at which this error occurred. */ - const std::time_t& time() const { return time_; } - /*! \return The line number at which this error occurred. */ - int lineno() const { return lineno_; } - virtual const char* what() const noexcept { return full_message_.c_str(); } - private: - std::string file_; - int lineno_; - std::string message_; - std::time_t time_; - std::string backtrace_; - std::string full_message_; // holds the full error string + static std::string DetectMessage(const std::string& message) { + size_t pos = message.find("Error:"); + if (pos != std::string::npos) { + size_t end = pos + 6; + size_t begin = pos; + for (; begin != 0 && message[begin - 1] != ' '; --begin) { + } + if (end < message.size() && message[end] == ' ') { + end += 1; + } + return message.substr(0, begin) + message.substr(end); + } else { + return message; + } + } }; /*! \brief Internal implementation */ diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 1199c420f212..96ca0c8d696a 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -17,190 +17,20 @@ * under the License. */ /*! - * \file tvm/runtime/memory.h - * \brief Runtime memory management. + * \file tvm/runtime/object.h + * \brief A managed object in the TVM runtime. */ #ifndef TVM_RUNTIME_MEMORY_H_ #define TVM_RUNTIME_MEMORY_H_ -#include - -#include -#include -#include +#include namespace tvm { namespace runtime { -/*! - * \brief Allocate an object using default allocator. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The ObjectPtr to the allocated object. - */ -template -inline ObjectPtr make_object(Args&&... args); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. - -/*! - * \brief Base class of object allocators that implements make. - * Use curiously recurring template pattern. - * - * \tparam Derived The derived class. - */ -template -class ObjAllocatorBase { - public: - /*! - * \brief Make a new object using the allocator. - * \tparam T The type to be allocated. - * \tparam Args The constructor signature. - * \param args The arguments. - */ - template - inline ObjectPtr make_object(Args&&... args) { - using Handler = typename Derived::template Handler; - static_assert(std::is_base_of::value, "make can only be used to create Object"); - T* ptr = Handler::New(static_cast(this), std::forward(args)...); - ptr->type_index_ = T::RuntimeTypeIndex(); - ptr->deleter_ = Handler::Deleter(); - return ObjectPtr(ptr); - } - - /*! - * \tparam ArrayType The type to be allocated. - * \tparam ElemType The type of array element. - * \tparam Args The constructor signature. - * \param num_elems The number of array elements. - * \param args The arguments. - */ - template - inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { - using Handler = typename Derived::template ArrayHandler; - static_assert(std::is_base_of::value, - "make_inplace_array can only be used to create Object"); - ArrayType* ptr = - Handler::New(static_cast(this), num_elems, std::forward(args)...); - ptr->type_index_ = ArrayType::RuntimeTypeIndex(); - ptr->deleter_ = Handler::Deleter(); - return ObjectPtr(ptr); - } -}; - -// Simple allocator that uses new/delete. -class SimpleObjAllocator : public ObjAllocatorBase { - public: - template - class Handler { - public: - using StorageType = typename std::aligned_storage::type; - - template - static T* New(SimpleObjAllocator*, Args&&... args) { - // NOTE: the first argument is not needed for SimpleObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. - StorageType* data = new StorageType(); - new (data) T(std::forward(args)...); - return reinterpret_cast(data); - } - - static Object::FDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(Object* objptr) { - // NOTE: this is important to cast back to T* - // because objptr and tptr may not be the same - // depending on how sub-class allocates the space. - T* tptr = static_cast(objptr); - // It is important to do tptr->T::~T(), - // so that we explicitly call the specific destructor - // instead of tptr->~T(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->T::~T(); - delete reinterpret_cast(tptr); - } - }; - - // Array handler that uses new/delete. - template - class ArrayHandler { - public: - using StorageType = typename std::aligned_storage::type; - // for now only support elements that aligns with array header. - static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, - "element alignment constraint"); - - template - static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { - // NOTE: the first argument is not needed for ArrayObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. - size_t unit = sizeof(StorageType); - size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); - size_t num_storage_slots = (requested_size + unit - 1) / unit; - StorageType* data = new StorageType[num_storage_slots]; - new (data) ArrayType(std::forward(args)...); - return reinterpret_cast(data); - } - - static Object::FDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(Object* objptr) { - // NOTE: this is important to cast back to ArrayType* - // because objptr and tptr may not be the same - // depending on how sub-class allocates the space. - ArrayType* tptr = static_cast(objptr); - // It is important to do tptr->ArrayType::~ArrayType(), - // so that we explicitly call the specific destructor - // instead of tptr->~ArrayType(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->ArrayType::~ArrayType(); - StorageType* p = reinterpret_cast(tptr); - delete[] p; - } - }; -}; - -template -inline ObjectPtr make_object(Args&&... args) { - return SimpleObjAllocator().make_object(std::forward(args)...); -} -template -inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { - return SimpleObjAllocator().make_inplace_array(num_elems, - std::forward(args)...); -} +using tvm::ffi::FObjectDeleter; +using tvm::ffi::make_inplace_array_object; +using tvm::ffi::make_object; } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h index ab1e6b5c9f6d..537beeb8fa9d 100644 --- a/include/tvm/runtime/memory/memory_manager.h +++ b/include/tvm/runtime/memory/memory_manager.h @@ -170,19 +170,12 @@ class StorageObj : public Object { TVM_DLL NDArray AllocNDArrayScoped(int64_t offset, ShapeTuple shape, DLDataType dtype, String scope = "global"); - /*! \brief The deleter for an NDArray when allocated from underlying storage. */ - static void ScopedDeleter(Object* ptr); - - /*! \brief The deleter for an NDArray when allocated from underlying storage. */ - static void Deleter(Object* ptr); - ~StorageObj() { if (allocator) { allocator->Free(buffer); } } - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "vm.Storage"; TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object); }; diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 308c936d624a..69e8d4283ede 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,6 +27,7 @@ #define TVM_RUNTIME_MODULE_H_ #include +#include #include #include #include @@ -41,6 +42,8 @@ namespace tvm { namespace runtime { +using PackedFunc = ffi::Function; + /*! * \brief Property of runtime module * We classify the property of runtime module into the following categories. @@ -71,7 +74,6 @@ enum ModulePropertyMask : int { }; class ModuleNode; -class PackedFunc; /*! * \brief Module container of TVM. @@ -253,11 +255,11 @@ class TVM_DLL ModuleNode : public Object { virtual bool ImplementsFunction(const String& name, bool query_imports = false); // integration with the existing components. - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule; + static constexpr const uint32_t _type_index = ffi::TypeIndex::kTVMFFIModule; static constexpr const char* _type_key = "runtime.Module"; // NOTE: ModuleNode can still be sub-classed // - TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleNode, Object); protected: friend class Module; @@ -278,6 +280,11 @@ class TVM_DLL ModuleNode : public Object { */ TVM_DLL bool RuntimeEnabled(const String& target); +// implementation of Module::GetFunction +inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { + return (*this)->GetFunction(name, query_imports); +} + /*! \brief namespace for constant symbols */ namespace symbol { /*! \brief A PackedFunc that retrieves exported metadata. */ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index fef61a753103..82c9b229ab90 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_NDARRAY_H_ #define TVM_RUNTIME_NDARRAY_H_ +#include #include #include #include #include #include +#include #include #include @@ -38,43 +40,39 @@ #include namespace tvm { - -// alias DLDevice -using Device = DLDevice; - namespace runtime { +using ffi::GetDataSize; +using ffi::IsAligned; +using ffi::IsContiguous; + /*! * \brief Managed NDArray. * The array is backed by reference counted blocks. */ -class NDArray : public ObjectRef { +class NDArray : public tvm::ffi::NDArray { public: - /*! \brief ContainerBase used to back the TVMArrayHandle */ - class ContainerBase; - /*! \brief NDArray internal container type */ - class Container; - /*! \brief Container type for Object system. */ - using ContainerType = Container; - /*! \brief default constructor */ - NDArray() {} + using Container = ffi::NDArrayObj; + NDArray() = default; /*! * \brief constructor. * \param data ObjectPtr to the data container. */ - explicit NDArray(ObjectPtr data) : ObjectRef(data) {} + explicit NDArray(ObjectPtr data) : tvm::ffi::NDArray(data) {} + NDArray(ffi::NDArray&& other) : tvm::ffi::NDArray(std::move(other)) {} // NOLINT(*) + NDArray(const ffi::NDArray& other) : tvm::ffi::NDArray(other) {} // NOLINT(*) - /*! \brief reset the content of NDArray to be nullptr */ - inline void reset(); - /*! - * \return the reference counter - * \note this number is approximate in multi-threaded setting. - */ - inline int use_count() const; - /*! \return Pointer to content of DLTensor */ - inline const DLTensor* operator->() const; - /*! \return Whether the tensor is contiguous */ - inline bool IsContiguous() const; + ShapeTuple Shape() const { return this->shape(); } + runtime::DataType DataType() const { return runtime::DataType(this->dtype()); } + + // DLPack handling + static NDArray FromDLPack(DLManagedTensor* tensor) { + return tvm::ffi::NDArray::FromDLPack(tensor, kAllocAlignment, true); + } + + static NDArray FromDLPackVersioned(DLManagedTensorVersioned* tensor) { + return tvm::ffi::NDArray::FromDLPackVersioned(tensor, kAllocAlignment, true); + } /*! * \brief Copy data content from another array. * \param other The source array to be copied from. @@ -147,14 +145,8 @@ class NDArray : public ObjectRef { * outside the bounds of the current array, this function will * raise an exception. */ - TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype, uint64_t relative_byte_offset = 0); - - /*! - * \brief Create a reference view of NDArray that - * represents as DLManagedTensor. - * \return A DLManagedTensor - */ - TVM_DLL DLManagedTensor* ToDLPack() const; + TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype, + uint64_t relative_byte_offset = 0) const; /*! * \brief Create an empty NDArray. * \param shape The shape of the new array. @@ -165,37 +157,6 @@ class NDArray : public ObjectRef { */ TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope = NullOpt); - /*! - * \brief Create a NDArray backed by an external DLTensor without memory copying. - * - * If DLTensor is not contiguous or has bad aligned data, It fails. - * This allows us to create a NDArray using the memory - * allocated by an external source. Responsibility for memory - * retaining lies with the external source. - * \param dl_tensor The DLTensor for NDArray base. - * \return The created NDArray view. - */ - TVM_DLL static NDArray FromExternalDLTensor(const DLTensor& dl_tensor); - /*! - * \brief Create new NDArray, data is copied from DLTensor. - * - * \param dl_tensor The DLTensor to copy from. - * \param dev device location of the created NDArray. - * \return The created NDArray view. - */ - TVM_DLL static NDArray NewFromDLTensor(DLTensor* dl_tensor, const Device& dev); - /*! - * \brief Create a NDArray backed by a dlpack tensor. - * - * This allows us to create a NDArray using the memory - * allocated by an external deep learning framework - * that is DLPack compatible. - * - * The memory is retained until the NDArray went out of scope. - * \param tensor The DLPack tensor to copy from. - * \return The created NDArray view. - */ - TVM_DLL static NDArray FromDLPack(DLManagedTensor* tensor); /*! * \brief Function to copy data from one array to another. * \param from The source array. @@ -205,47 +166,9 @@ class NDArray : public ObjectRef { TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); - TVM_DLL ShapeTuple Shape() const; - TVM_DLL runtime::DataType DataType() const; - /*! - * \brief Check conditions for construction NDArray over DLTensor without copying. - * There are three conditions to check: - * 1. Destination device is the same as DLTensor device - * 2. Destination device id is the same as DLTensor device id - * 3. Memory in DLTensor is aligned as expected for NDArray - * \param tensor the DLTensor. - * \param dev destination device. - * \return true if all conditions are satisfied. - */ - TVM_DLL static bool AbilityOfZeroCopyForDLTensor(DLTensor* tensor, const Device& dev); - // internal namespace struct Internal; - private: - TVM_DLL static bool IsAligned(const DLTensor& tensor); - protected: - friend class TVMPODValue_; - template - friend class TVMPODValue_CRTP_; - friend class TVMRetValue; - friend class TVMArgsSetter; - /*! - * \brief Get mutable internal container pointer. - * \return a mutable container pointer. - */ - inline Container* get_mutable() const; - // Helper functions for FFI handling. - /*! - * \brief Construct NDArray's Data field from array handle in FFI. - * \param handle The array handle. - * \return The corresponding ObjectPtr to the constructed container object. - * - * \note We keep a special calling convention for NDArray by passing - * ContainerBase pointer in FFI. - * As a result, the argument is compatible to DLTensor*. - */ - inline static ObjectPtr FFIDataFromHandle(TVMArrayHandle handle); /*! * \brief DecRef resource managed by an FFI array handle. * \param handle The array handle. @@ -266,185 +189,48 @@ class NDArray : public ObjectRef { */ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); -/*! - * \brief The container base structure - * contains all the fields except for the Object header. - * - * \note We explicitly declare this structure in order to pass - * PackedFunc argument using ContainerBase*. - */ -class NDArray::ContainerBase { - public: - /*! - * \brief The corresponding dl_tensor field. - * \note it is important that the first field is DLTensor - * So that this data structure is DLTensor compatible. - * The head ptr of this struct can be viewed as DLTensor*. - */ - DLTensor dl_tensor; - - /*! - * \brief additional context, reserved for recycling - * \note We can attach additional content here - * which the current container depend on - * (e.g. reference to original memory when creating views). - */ - void* manager_ctx{nullptr}; - - protected: - /*! - * \brief The shape container, - * can be used for shape data. - */ - ShapeTuple shape_; -}; - -/*! - * \brief Object container class that backs NDArray. - * \note do not use this function directly, use NDArray. - */ -class NDArray::Container : public Object, public NDArray::ContainerBase { - public: - /*! \brief default constructor */ - Container() { - // Initialize the type index. - type_index_ = Container::RuntimeTypeIndex(); - dl_tensor.data = nullptr; - dl_tensor.ndim = 0; - dl_tensor.shape = nullptr; - dl_tensor.strides = nullptr; - dl_tensor.byte_offset = 0; - } - - Container(void* data, ShapeTuple shape, DLDataType dtype, Device dev) { - // Initialize the type index. - type_index_ = Container::RuntimeTypeIndex(); - dl_tensor.data = data; - shape_ = std::move(shape); - dl_tensor.ndim = static_cast(shape_.size()); - dl_tensor.shape = const_cast(shape_.data()); - dl_tensor.dtype = dtype; - dl_tensor.strides = nullptr; - dl_tensor.byte_offset = 0; - dl_tensor.device = dev; - } - /*! - * \brief Set the deleter field. - * \param deleter The deleter. - */ - void SetDeleter(FDeleter deleter) { deleter_ = deleter; } - - // Expose DecRef and IncRef as public function - // NOTE: they are only for developer purposes only. - using Object::DecRef; - using Object::IncRef; - - // Information for object protocol. - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray; - static constexpr const uint32_t _type_child_slots = 0; - static constexpr const uint32_t _type_child_slots_can_overflow = true; - static constexpr const char* _type_key = "runtime.NDArray"; - TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object); - - protected: - friend class RPCWrappedFunc; - friend class NDArray; -}; - -// implementations of inline functions -/*! - * \brief return the size of data the DLTensor hold, in term of number of bytes - * - * \param arr the input DLTensor - * \return number of bytes of data in the DLTensor. - */ -inline size_t GetDataSize(const DLTensor& arr) { - size_t size = 1; - for (tvm_index_t i = 0; i < arr.ndim; ++i) { - size *= static_cast(arr.shape[i]); - } - size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; - return size; -} - -/*! - * \brief check if a DLTensor is contiguous. - * \param arr The input DLTensor. - * \return The check result. - */ -static inline bool IsContiguous(const DLTensor& arr) { - if (arr.strides == nullptr) return true; - int64_t expected_stride = 1; - for (int32_t i = arr.ndim; i != 0; --i) { - int32_t k = i - 1; - if (arr.shape[k] == 1) { - // Skip stride check if shape[k] is 1, where the dimension is contiguous - // regardless of the value of stride. - // - // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting - // to DLPack. - // More context: https://github.com/pytorch/pytorch/pull/83158 - continue; - } - if (arr.strides[k] != expected_stride) return false; - expected_stride *= arr.shape[k]; - } - return true; -} - -inline bool NDArray::IsContiguous() const { - return ::tvm::runtime::IsContiguous(get_mutable()->dl_tensor); -} - inline void NDArray::CopyFrom(const DLTensor* other) { ICHECK(data_ != nullptr); - CopyFromTo(other, &(get_mutable()->dl_tensor)); + CopyFromTo(other, get_mutable()); } inline void NDArray::CopyFrom(const NDArray& other) { ICHECK(data_ != nullptr); ICHECK(other.data_ != nullptr); - CopyFromTo(&(other.get_mutable()->dl_tensor), &(get_mutable()->dl_tensor)); + CopyFromTo(other.get_mutable(), get_mutable()); } inline void NDArray::CopyTo(DLTensor* other) const { ICHECK(data_ != nullptr); - CopyFromTo(&(get_mutable()->dl_tensor), other); + CopyFromTo(get_mutable(), other); } inline void NDArray::CopyTo(const NDArray& other) const { ICHECK(data_ != nullptr); ICHECK(other.data_ != nullptr); - CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor)); -} - -inline int NDArray::use_count() const { return data_.use_count(); } - -inline const DLTensor* NDArray::operator->() const { return &(get_mutable()->dl_tensor); } - -inline NDArray::Container* NDArray::get_mutable() const { - return static_cast(data_.get()); -} - -inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { - return GetObjectPtr( - static_cast(reinterpret_cast(handle))); + CopyFromTo(get_mutable(), other.get_mutable()); } inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. - auto ptr = reinterpret_cast(static_cast( - static_cast(const_cast(nd.get())))); + auto ptr = reinterpret_cast( + TVMFFINDArrayGetDLTensorPtr(static_cast(const_cast(nd.get())))); return ptr; } -inline void NDArray::FFIDecRef(TVMArrayHandle handle) { - static_cast(reinterpret_cast(handle))->DecRef(); +inline TVMArrayHandle ObjectHandleToTVMArrayHandle(Object* handle) { + return reinterpret_cast( + TVMFFINDArrayGetDLTensorPtr(static_cast(handle))); } -inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) { - return static_cast(reinterpret_cast(handle)); +inline Object* TVMArrayHandleToObjectHandle(void* handle) { + // NOTE: legacy patch here for TFM FFI + return reinterpret_cast(reinterpret_cast(handle) - sizeof(TVMFFIObject)); +} + +inline void NDArray::FFIDecRef(TVMArrayHandle handle) { + ffi::details::ObjectUnsafe::DecRefObjectHandle(TVMArrayHandleToObjectHandle(handle)); } /*! \brief Magic number for NDArray file */ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index eef687683e39..d9e1ac25177d 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -23,709 +23,57 @@ #ifndef TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_ +#include +#include #include -#include +#include -#include -#include #include -/*! - * \brief Whether or not use atomic reference counter. - * If the reference counter is not atomic, - * an object cannot be owned by multiple threads. - * We can, however, move an object across threads - */ -#ifndef TVM_OBJECT_ATOMIC_REF_COUNTER -#define TVM_OBJECT_ATOMIC_REF_COUNTER 1 -#endif - -#if TVM_OBJECT_ATOMIC_REF_COUNTER -#include -#endif // TVM_OBJECT_ATOMIC_REF_COUNTER - namespace tvm { namespace runtime { +using tvm::ffi::Object; +using tvm::ffi::ObjectPtr; +using tvm::ffi::ObjectPtrEqual; +using tvm::ffi::ObjectPtrHash; +using tvm::ffi::ObjectRef; + +using tvm::ffi::Downcast; +using tvm::ffi::GetObjectPtr; +using tvm::ffi::GetRef; + /*! * \brief Namespace for the list of type index. * \note Use struct so that we have to use TypeIndex::ENumName to refer to * the constant, but still able to use enum. */ -struct TypeIndex { - enum { - /*! \brief Root object type. */ - kRoot = 0, - // Standard static index assignments, - // Frontends can take benefit of these constants. - /*! \brief runtime::Module. */ - kRuntimeModule = 1, - /*! \brief runtime::NDArray. */ - kRuntimeNDArray = 2, - /*! \brief runtime::String. */ - kRuntimeString = 3, - /*! \brief runtime::Array. */ - kRuntimeArray = 4, - /*! \brief runtime::Map. */ - kRuntimeMap = 5, - /*! \brief runtime::ShapeTuple. */ - kRuntimeShapeTuple = 6, - /*! \brief runtime::PackedFunc. */ - kRuntimePackedFunc = 7, - /*! \brief runtime::DRef for disco distributed runtime */ - kRuntimeDiscoDRef = 8, - /*! \brief runtime::RPCObjectRef */ - kRuntimeRPCObjectRef = 9, - // static assignments that may subject to change. - kStaticIndexEnd = 10, - /*! - * \brief Type index is allocated during runtime, keeping it as - * constant for now to ensure compatibility across versions - */ - kDynamic = 12 - }; -}; // namespace TypeIndex - -/*! - * \brief base class of all object containers. - * - * Sub-class of objects should declare the following static constexpr fields: - * - * - _type_index: - * Static type index of the object, if assigned to TypeIndex::kDynamic - * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::TypeIndex(); - * - _type_key: - * The unique string identifier of the type. - * - _type_final: - * Whether the type is terminal type(there is no subclass of the type in the object system). - * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO - * It is still OK to sub-class a terminal object type T and construct it using make_object. - * But IsInstance check will only show that the object type is T(instead of the sub-class). - * - * The following two fields are necessary for base classes that can be sub-classed. - * - * - _type_child_slots: - * Number of reserved type index slots for child classes. - * Used for runtime optimization for type checking in IsInstance. - * If an object's type_index is within range of [type_index, type_index + _type_child_slots] - * Then the object can be quickly decided as sub-class of the current object class. - * If not, a fallback mechanism is used to check the global type table. - * Recommendation: set to estimate number of children needed. - * - _type_child_slots_can_overflow: - * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check global type table will be - * used. Recommendation: set to false for optimal runtime speed if we know exact number of children. - * - * Two macros are used to declare helper functions in the object: - * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. - * - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. - * - * New objects can be created using make_object function. - * Which will automatically populate the type_index and deleter of the object. - * - * \sa make_object - * \sa ObjectPtr - * \sa ObjectRef - * - * \code - * - * // Create a base object - * class BaseObj : public Object { - * public: - * // object fields - * int field0; - * - * // object properties - * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - * static constexpr const char* _type_key = "test.BaseObj"; - * TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object); - * }; - * - * class LeafObj : public BaseObj { - * public: - * // fields - * int child_field0; - * // object properties - * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - * static constexpr const char* _type_key = "test.LeafObj"; - * TVM_DECLARE_BASE_OBJECT_INFO(LeafObj, Object); - * }; - * - * // The following code should be put into a cc file. - * TVM_REGISTER_OBJECT_TYPE(BaseObj); - * TVM_REGISTER_OBJECT_TYPE(LeafObj); - * - * // Usage example. - * void TestObjects() { - * // create an object - * ObjectRef leaf_ref(make_object()); - * // cast to a specific instance - * const LeafObj* leaf_ptr = leaf_ref.as(); - * ICHECK(leaf_ptr != nullptr); - * // can also cast to the base class. - * ICHECK(leaf_ref.as() != nullptr); - * } - * - * \endcode - */ -class TVM_DLL Object { - public: - /*! - * \brief Object deleter - * \param self pointer to the Object. - */ - typedef void (*FDeleter)(Object* self); - /*! \return The internal runtime type index of the object. */ - uint32_t type_index() const { return type_index_; } - /*! - * \return the type key of the object. - * \note this operation is expensive, can be used for error reporting. - */ - std::string GetTypeKey() const { return TypeIndex2Key(type_index_); } - /*! - * \return A hash value of the return of GetTypeKey. - */ - size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); } - /*! - * Check if the object is an instance of TargetType. - * \tparam TargetType The target type to be checked. - * \return Whether the target type is true. - */ - template - inline bool IsInstance() const; - /*! - * \return Whether the cell has only one reference - * \note We use stl style naming to be consistent with known API in shared_ptr. - */ - inline bool unique() const; - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - * \return the result. - */ - static std::string TypeIndex2Key(uint32_t tindex); - /*! - * \brief Get the type key hash of the corresponding index from runtime. - * \param tindex The type index. - * \return the related key-hash. - */ - static size_t TypeIndex2KeyHash(uint32_t tindex); - /*! - * \brief Get the type index of the corresponding key from runtime. - * \param key The type key. - * \return the result. - */ - static uint32_t TypeKey2Index(const std::string& key); - -#if TVM_OBJECT_ATOMIC_REF_COUNTER - using RefCounterType = std::atomic; -#else - using RefCounterType = int32_t; -#endif - - static constexpr const char* _type_key = "runtime.Object"; - - static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; } - static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; } - - // Default object type properties for sub-classes - static constexpr bool _type_final = false; - static constexpr uint32_t _type_child_slots = 0; - static constexpr bool _type_child_slots_can_overflow = true; - // member information - static constexpr bool _type_has_method_visit_attrs = true; - static constexpr bool _type_has_method_sequal_reduce = false; - static constexpr bool _type_has_method_shash_reduce = false; - // NOTE: the following field is not type index of Object - // but was intended to be used by sub-classes as default value. - // The type index of Object is TypeIndex::kRoot - static constexpr uint32_t _type_index = TypeIndex::kDynamic; - - // Default constructor and copy constructor - Object() {} - // Override the copy and assign constructors to do nothing. - // This is to make sure only contents, but not deleter and ref_counter - // are copied when a child class copies itself. - // This will enable us to use make_object(*obj_ptr) - // to copy an existing object. - Object(const Object& other) { // NOLINT(*) - } - Object(Object&& other) { // NOLINT(*) - } - Object& operator=(const Object& other) { // NOLINT(*) - return *this; - } - Object& operator=(Object&& other) { // NOLINT(*) - return *this; - } - - protected: - // The fields of the base object cell. - /*! \brief Type index(tag) that indicates the type of the object. */ - uint32_t type_index_{0}; - /*! \brief The internal reference counter */ - RefCounterType ref_counter_{0}; - /*! - * \brief deleter of this object to enable customized allocation. - * If the deleter is nullptr, no deletion will be performed. - * The creator of the object must always set the deleter field properly. - */ - FDeleter deleter_ = nullptr; - // Invariant checks. - static_assert(sizeof(int32_t) == sizeof(RefCounterType) && - alignof(int32_t) == sizeof(RefCounterType), - "RefCounter ABI check."); - - /*! - * \brief Get the type index using type key. - * - * When the function is first time called for a type, - * it will register the type to the type table in the runtime. - * If the static_tindex is TypeIndex::kDynamic, the function will - * allocate a runtime type index. - * Otherwise, we will populate the type table and return the static index. - * - * \param key the type key. - * \param static_tindex The current _type_index field. - * can be TypeIndex::kDynamic. - * \param parent_tindex The index of the parent. - * \param type_child_slots Number of slots reserved for its children. - * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. - * \return The allocated type index. - */ - static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, - uint32_t parent_tindex, uint32_t type_child_slots, - bool type_child_slots_can_overflow); - - // reference counter related operations - /*! \brief developer function, increases reference counter. */ - inline void IncRef(); - /*! - * \brief developer function, decrease reference counter. - * \note The deleter will be called when ref_counter_ becomes zero. - */ - inline void DecRef(); - - private: - /*! - * \return The usage count of the cell. - * \note We use stl style naming to be consistent with known API in shared_ptr. - */ - inline int use_count() const; - /*! - * \brief Check of this object is derived from the parent. - * \param parent_tindex The parent type index. - * \return The derivation results. - */ - bool DerivedFrom(uint32_t parent_tindex) const; - // friend classes - template - friend class ObjAllocatorBase; - template - friend class ObjectPtr; - friend class TVMRetValue; - friend class ObjectInternal; +enum TypeIndex : int32_t { + // Standard static index assignments, + // Frontends can take benefit of these constants. + + /*! \brief runtime::Module. */ + kRuntimeModule = TVMFFITypeIndex::kTVMFFIModule, + /*! \brief runtime::NDArray. */ + kRuntimeNDArray = TVMFFITypeIndex::kTVMFFINDArray, + /*! \brief runtime::ShapeTuple. */ + kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShape, + // Extra builtin static index here + kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIStaticObjectEnd, + /*! \brief runtime::PackedFunc. */ + kRuntimePackedFunc = kCustomStaticIndex + 1, + /*! \brief runtime::DRef for disco distributed runtime */ + kRuntimeDiscoDRef = kCustomStaticIndex + 2, + /*! \brief runtime::RPCObjectRef */ + kRuntimeRPCObjectRef = kCustomStaticIndex + 3, + // custom builtin + kRuntimeString, + kRuntimeMap, + kRuntimeArray, + // static assignments that may subject to change. + kStaticIndexEnd, }; -/*! - * \brief Get a reference type from a raw object ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the object alive beyond the scope of the function. - * - * \param ptr The object pointer - * \tparam RefType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline ObjectRefType GetRef(const ObjectType* ptr); - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template -inline SubRef Downcast(BaseRef ref); - -/*! - * \brief A custom smart pointer for Object. - * \tparam T the content data type. - * \sa make_object - */ -template -class ObjectPtr { - public: - /*! \brief default constructor */ - ObjectPtr() {} - /*! \brief default constructor */ - ObjectPtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - ObjectPtr(const ObjectPtr& other) // NOLINT(*) - : ObjectPtr(other.data_) {} - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - ObjectPtr(const ObjectPtr& other) // NOLINT(*) - : ObjectPtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - ObjectPtr(ObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - ObjectPtr(ObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~ObjectPtr() { this->reset(); } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - void swap(ObjectPtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { return static_cast(data_); } - /*! - * \return The pointer - */ - T* operator->() const { return get(); } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr& operator=(const ObjectPtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - ObjectPtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr& operator=(ObjectPtr&& other) { // NOLINT(*) - // copy-and-swap idiom - ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief nullptr check - * \return result of comparison of internal pointer with nullptr. - */ - explicit operator bool() const { return get() != nullptr; } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - /*! \return whether the reference is unique */ - bool unique() const { return data_ != nullptr && data_->use_count() == 1; } - /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } - /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { return data_ == nullptr; } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } - - private: - /*! \brief internal pointer field */ - Object* data_{nullptr}; - /*! - * \brief constructor from Object - * \param data The data pointer - */ - explicit ObjectPtr(Object* data) : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } - } - /*! - * \brief Move an ObjectPtr from an RValueRef argument. - * \param ref The rvalue reference. - * \return the moved result. - */ - static ObjectPtr MoveFromRValueRefArg(Object** ref) { - ObjectPtr ptr; - ptr.data_ = *ref; - *ref = nullptr; - return ptr; - } - // friend classes - friend class Object; - friend class ObjectRef; - friend struct ObjectPtrHash; - template - friend class ObjectPtr; - template - friend class ObjAllocatorBase; - friend class TVMPODValue_; - friend class TVMArgsSetter; - friend class TVMRetValue; - friend class TVMArgValue; - friend class TVMMovableArgValue_; - template - friend ObjectRefType GetRef(const ObjType* ptr); - template - friend ObjectPtr GetObjectPtr(ObjType* ptr); -}; - -// Forward declaration, to prevent circular includes. -template -class Optional; - -/*! \brief Base class of all object reference */ -class ObjectRef { - public: - /*! \brief default constructor */ - ObjectRef() = default; - /*! \brief Constructor from existing object ptr */ - explicit ObjectRef(ObjectPtr data) : data_(data) {} - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool same_as(const ObjectRef& other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator==(const ObjectRef& other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } - /*! - * \brief Comparator - * \param other Another object ref by address. - * \return the compare result. - */ - bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! - * \return whether the object is defined(not null). - */ - bool defined() const { return data_ != nullptr; } - /*! \return the internal object pointer */ - const Object* get() const { return data_.get(); } - /*! \return the internal object pointer */ - const Object* operator->() const { return get(); } - /*! \return whether the reference is unique */ - bool unique() const { return data_.unique(); } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_.use_count(); } - - /*! - * \brief Try to downcast the internal Object to a - * raw pointer of a corresponding type. - * - * The function will return a nullptr if the cast failed. - * - * if (const AddNode *ptr = node_ref.as()) { - * // This is an add node - * } - * - * \tparam ObjectType the target type, must be a subtype of Object - */ - template >> - inline const ObjectType* as() const; - - /*! - * \brief Try to downcast the ObjectRef to a - * Optional of the requested type. - * - * The function will return a NullOpt if the cast failed. - * - * if (Optional opt = node_ref.as()) { - * // This is an add node - * } - * - * \note While this method is declared in , - * the implementation is in to - * prevent circular includes. This additional include file is only - * required in compilation units that uses this method. - * - * \tparam ObjectRefType the target type, must be a subtype of ObjectRef - */ - template >> - inline Optional as() const; - - /*! \brief type indicate the container type. */ - using ContainerType = Object; - // Default type properties for the reference class. - static constexpr bool _type_is_nullable = true; - - protected: - /*! \brief Internal pointer that backs the reference. */ - ObjectPtr data_; - /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object* get_mutable() const { return data_.get(); } - /*! - * \brief Internal helper function downcast a ref without check. - * \note Only used for internal dev purposes. - * \tparam T The target reference type. - * \return The casted result. - */ - template - static T DowncastNoCheck(ObjectRef ref) { - return T(std::move(ref.data_)); - } - /*! - * \brief Clear the object ref data field without DecRef - * after we successfully moved the field. - * \param ref The reference data. - */ - static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; } - /*! - * \brief Internal helper function get data_ as ObjectPtr of ObjectType. - * \note only used for internal dev purpose. - * \tparam ObjectType The corresponding object type. - * \return the corresponding type. - */ - template - static ObjectPtr GetDataPtr(const ObjectRef& ref) { - return ObjectPtr(ref.data_.data_); - } - // friend classes. - friend struct ObjectPtrHash; - friend class TVMRetValue; - friend class TVMArgsSetter; - friend class ObjectInternal; - template - friend SubRef Downcast(BaseRef ref); -}; - -/*! - * \brief Get an object ptr type from a raw object ptr. - * - * \param ptr The object pointer - * \tparam BaseType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline ObjectPtr GetObjectPtr(ObjectType* ptr); - -/*! \brief ObjectRef hash functor */ -struct ObjectPtrHash { - size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } - - template - size_t operator()(const ObjectPtr& a) const { - return std::hash()(a.get()); - } -}; - -/*! \brief ObjectRef equal functor */ -struct ObjectPtrEqual { - bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } - - template - size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { - return a == b; - } -}; - -/*! - * \brief helper macro to declare a base object type that can be inherited. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static_assert(!ParentType::_type_final, "ParentObj marked as final"); \ - static uint32_t RuntimeTypeIndex() { \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ - return TypeName::_type_index; \ - } \ - return _GetOrAllocRuntimeTypeIndex(); \ - } \ - static uint32_t _GetOrAllocRuntimeTypeIndex() { \ - static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \ - TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \ - TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \ - return tindex; \ - } - -/*! - * \brief helper macro to declare type information in a final class. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ - static const constexpr bool _type_final = true; \ - static const constexpr int _type_child_slots = 0; \ - TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) - -/*! \brief helper macro to suppress unused warning */ -#if defined(__GNUC__) -#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define TVM_ATTRIBUTE_UNUSED -#endif - -#define TVM_STR_CONCAT_(__x, __y) __x##__y -#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) - -#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid - -/*! - * \brief Helper macro to register the object type to runtime. - * Makes sure that the runtime type table is correctly populated. - * - * Use this macro in the cc file for each terminal class. - */ -#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ - TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex() - /* * \brief Define the default copy/move constructor and assign operator * \param TypeName The class typename. @@ -736,75 +84,6 @@ struct ObjectPtrEqual { TypeName& operator=(const TypeName& other) = default; \ TypeName& operator=(TypeName&& other) = default; -/* - * \brief Define object reference methods. - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ - ObjectName) \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { return static_cast(data_.get()); } \ - const ObjectName* get() const { return operator->(); } \ - using ContainerType = ObjectName; - -/* - * \brief Define object reference methods. - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, ObjectName) - -/* - * \brief Define object reference methods that is not nullable. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { return static_cast(data_.get()); } \ - const ObjectName* get() const { return operator->(); } \ - static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName; - -/* - * \brief Define object reference methods of whose content is mutable. - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - * \note We recommend making objects immutable when possible. - * This macro is only reserved for objects that stores runtime states. - */ -#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ - using ContainerType = ObjectName; - -/* - * \brief Define object reference methods that is both not nullable and mutable. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ - ObjectName* get() const { return operator->(); } \ - static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName; - /*! * \brief Define CopyOnWrite function in an ObjectRef. * \param ObjectName The Type of the Node. @@ -838,115 +117,38 @@ struct ObjectPtrEqual { return static_cast(data_.get()); \ } -// Implementations details below -// Object reference counting. -#if TVM_OBJECT_ATOMIC_REF_COUNTER - -inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } - -inline void Object::DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } -} - -inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); } - -#else - -inline void Object::IncRef() { ++ref_counter_; } - -inline void Object::DecRef() { - if (--ref_counter_ == 0) { - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } -} - -inline int Object::use_count() const { return ref_counter_; } - -#endif // TVM_OBJECT_ATOMIC_REF_COUNTER - -template -inline bool Object::IsInstance() const { - const Object* self = this; - // NOTE: the following code can be optimized by - // compiler dead-code elimination for already known constants. - if (self != nullptr) { - // Everything is a subclass of object. - if (std::is_same::value) return true; - if (TargetType::_type_final) { - // if the target type is a final type - // then we only need to check the equivalence. - return self->type_index_ == TargetType::RuntimeTypeIndex(); - } else { - // if target type is a non-leaf type - // Check if type index falls into the range of reserved slots. - uint32_t begin = TargetType::RuntimeTypeIndex(); - // The condition will be optimized by constant-folding. - if (TargetType::_type_child_slots != 0) { - uint32_t end = begin + TargetType::_type_child_slots + 1; - if (self->type_index_ >= begin && self->type_index_ < end) return true; - } else { - if (self->type_index_ == begin) return true; - } - if (!TargetType::_type_child_slots_can_overflow) return false; - // Invariance: parent index is always smaller than the child. - if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; - // The rare slower-path, check type hierarchy. - return self->DerivedFrom(TargetType::RuntimeTypeIndex()); - } - } else { - return false; - } -} - -inline bool Object::unique() const { return use_count() == 1; } +/* + * \brief Define object reference methods. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ + ObjectName) \ + explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ + using ContainerType = ObjectName; -template -inline const ObjectType* ObjectRef::as() const { - if (data_ != nullptr && data_->IsInstance()) { - return static_cast(data_.get()); - } else { - return nullptr; - } -} +#define TVM_DECLARE_BASE_OBJECT_INFO TVM_FFI_DECLARE_BASE_OBJECT_INFO +#define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS -template -inline RefType GetRef(const ObjType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - if (!RefType::_type_is_nullable) { - ICHECK(ptr != nullptr); - } - return RefType(ObjectPtr(const_cast(static_cast(ptr)))); -} +#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS +#define TVM_DEFINE_OBJECT_REF_METHODS TVM_FFI_DEFINE_OBJECT_REF_METHODS +#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS \ + TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS -template -inline ObjectPtr GetObjectPtr(ObjType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return ObjectPtr(static_cast(ptr)); -} +#define TVM_STR_CONCAT_(__x, __y) __x##__y +#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) -template -inline SubRef Downcast(BaseRef ref) { - if (ref.defined()) { - ICHECK(ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key - << " failed."; - } else { - ICHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of " - << SubRef::ContainerType::_type_key; - } - return SubRef(std::move(ref.data_)); -} +// Object register type is now a nop +#define TVM_REGISTER_OBJECT_TYPE(x) } // namespace runtime -} // namespace tvm +using tvm::ffi::ObjectPtr; +using tvm::ffi::ObjectRef; +} // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 07a6848bfeed..3609987d5585 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -24,2676 +24,440 @@ #ifndef TVM_RUNTIME_PACKED_FUNC_H_ #define TVM_RUNTIME_PACKED_FUNC_H_ +#include +#include #include -#include -#include -#include -#include -#include #include #include #include -#include -#include -#include -#include -#include #include -#include -#include #include #include -// Whether use TVM runtime in header only mode. -#ifndef TVM_RUNTIME_HEADER_ONLY -#define TVM_RUNTIME_HEADER_ONLY 0 -#endif - namespace tvm { namespace runtime { -// forward declarations -class TVMArgs; -class TVMArgValue; -class TVMMovableArgValueWithContext_; -class TVMRetValue; -class TVMArgsSetter; -template -class TypedPackedFunc; -template -struct SignaturePrinter; +using ffi::Any; +using ffi::AnyView; /*! - * \brief Object container class that backs PackedFunc. - * \note Do not use this function directly, use PackedFunc. + * \brief Utility function to convert legacy TVMArgValue to AnyView + * \note This routine is not fastest, but serves purpose to do transition of ABI. */ -class PackedFuncObj : public Object { - public: - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param rv The return value. - */ - TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc; - static constexpr const char* _type_key = "runtime.PackedFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object); - - protected: - /*! - * \brief Internal struct for extracting the callable method from callable type. - */ - template - struct Extractor { - /*! - * \brief Extracting the callable method from callable type. - * \param obj The base packed function object class. - * \param args The arguments - * \param rv The return value. - */ - static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv); - }; - - /*! \brief The internal callable function type. */ - using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); - - /*! - * \brief Constructing a packed function object from a function pointer. - * \param f_call_pack The function pointer used to call the packed function. - */ - explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {} - - /*! \brief Delete the default constructor explicitly. */ - PackedFuncObj() = delete; - - /*! \brief Internal callable function pointer used to call the packed function. */ - FCallPacked* f_call_packed_; -}; - -/*! \brief Derived object class for constructing PackedFuncObj. */ -template -class PackedFuncSubObj : public PackedFuncObj { - using TStorage = typename std::remove_cv::type>::type; +inline TVMFFIAny LegacyTVMArgValueToFFIAny(TVMValue value, int type_code) { + TVMFFIAny res; + // clear first to ensure consistent hash + res.v_uint64 = 0; + switch (type_code) { + case kTVMArgInt: { + res.type_index = ffi::TypeIndex::kTVMFFIInt; + res.v_int64 = value.v_int64; + return res; + } + case kTVMArgFloat: { + res.type_index = ffi::TypeIndex::kTVMFFIFloat; + res.v_float64 = value.v_float64; + return res; + } + case kTVMOpaqueHandle: { + res.type_index = ffi::TypeIndex::kTVMFFIOpaquePtr; + res.v_ptr = value.v_handle; + return res; + } + case kTVMNullptr: { + res.type_index = ffi::TypeIndex::kTVMFFINone; + return res; + } + case kTVMDataType: { + res.type_index = ffi::TypeIndex::kTVMFFIDataType; + res.v_dtype = value.v_type; + return res; + } + case kDLDevice: { + res.type_index = ffi::TypeIndex::kTVMFFIDevice; + res.v_device = value.v_device; + return res; + } + case kTVMDLTensorHandle: { + res.type_index = ffi::TypeIndex::kTVMFFIDLTensorPtr; + res.v_ptr = value.v_handle; + return res; + } + case kTVMObjectHandle: { + res.v_obj = static_cast(value.v_handle); + res.type_index = res.v_obj->type_index; + return res; + } + case kTVMModuleHandle: { + res.type_index = ffi::TypeIndex::kTVMFFIModule; + res.v_obj = static_cast(value.v_handle); + return res; + } + case kTVMPackedFuncHandle: { + res.type_index = ffi::TypeIndex::kTVMFFIFunction; + res.v_obj = static_cast(value.v_handle); + return res; + } + case kTVMStr: { + res.type_index = ffi::TypeIndex::kTVMFFIRawStr; + res.v_c_str = value.v_str; + return res; + } + case kTVMBytes: { + res.type_index = ffi::TypeIndex::kTVMFFIByteArrayPtr; + res.v_ptr = value.v_handle; + return res; + } + case kTVMNDArrayHandle: { + res.type_index = ffi::TypeIndex::kTVMFFINDArray; + res.v_obj = reinterpret_cast(TVMArrayHandleToObjectHandle(value.v_handle)); + return res; + } + case kTVMArgBool: { + res.type_index = ffi::TypeIndex::kTVMFFIBool; + res.v_int64 = value.v_int64; + return res; + } + case kTVMObjectRValueRefArg: { + res.type_index = ffi::TypeIndex::kTVMFFIObjectRValueRef; + res.v_ptr = value.v_handle; + return res; + } + default: { + LOG(FATAL) << "Unsupported type code: " << type_code; + TVM_FFI_UNREACHABLE(); + } + } +} - public: - /*! \brief The type of derived object class */ - using TSelf = PackedFuncSubObj; - /*! - * \brief Derived object class for constructing PackedFuncObj. - * \param callable The type-erased callable object. - */ - explicit PackedFuncSubObj(TCallable callable) - : PackedFuncObj(Extractor::Call), callable_(callable) {} - /*! \brief Type-erased filed for storing callable object*/ - mutable TStorage callable_; -}; +/*! + * \brief Utility function to convert legacy TVMArgValue to AnyView + * \note This routine is not fastest, but serves purpose to do transition of ABI. + */ +inline AnyView LegacyTVMArgValueToAnyView(TVMValue value, int type_code) { + return AnyView::CopyFromTVMFFIAny(LegacyTVMArgValueToFFIAny(value, type_code)); +} /*! - * \brief Packed function is a type-erased function. - * The arguments are passed by packed format. - * - * This is an useful unified interface to call generated functions, - * It is the unified function type of TVM. - * It corresponds to TVMFunctionHandle in C runtime API. + * \brief Utility function to convert legacy TVMArgValue to Any + * \note This routine is not fastest, but serves purpose to do transition of ABI. */ -class PackedFunc : public ObjectRef { - public: - /*! \brief Constructor from null */ - PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*) - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `PackedFunc` - * \param data the internal container of packed function. - */ - template >::value && - !std::is_base_of::value>> - explicit PackedFunc(TCallable data) { - using ObjType = PackedFuncSubObj; - data_ = make_object(std::forward(data)); - } - /*! - * \brief Call packed function by directly passing in unpacked format. - * \param args Arguments to be passed. - * \tparam Args arguments to be passed. - * - * \code - * // Example code on how to call packed function - * void CallPacked(PackedFunc f) { - * // call like normal functions by pass in arguments - * // return value is automatically converted back - * int rvalue = f(1, 2.0); - * } - * \endcode - */ - template - inline TVMRetValue operator()(Args&&... args) const; - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param rv The return value. - */ - TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; - /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return data_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } +inline Any MoveLegacyTVMArgValueToAny(TVMValue value, int type_code) { + return ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(LegacyTVMArgValueToFFIAny(value, type_code)); +} - TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj); -}; +/* + * \brief Convert AnyView to legacy TVMValue and type_code + * \param src The AnyView to convert + * \param value The TVMValue to store the result + * \param type_code The type code to store the result + * \note This routine is not fastest, but serves purpose to do transition of ABI. + */ +inline void AnyViewToLegacyTVMArgValue(TVMFFIAny src, TVMValue* value, int* type_code) { + switch (src.type_index) { + case ffi::TypeIndex::kTVMFFIBool: { + type_code[0] = kTVMArgBool; + value[0].v_int64 = src.v_int64; + break; + } + case ffi::TypeIndex::kTVMFFIInt: { + type_code[0] = kDLInt; + value[0].v_int64 = src.v_int64; + break; + } + case ffi::TypeIndex::kTVMFFIFloat: { + type_code[0] = kDLFloat; + value[0].v_float64 = src.v_float64; + break; + } + case ffi::TypeIndex::kTVMFFIOpaquePtr: { + type_code[0] = kTVMOpaqueHandle; + value[0].v_handle = src.v_ptr; + break; + } + case ffi::TypeIndex::kTVMFFINone: { + type_code[0] = kTVMNullptr; + break; + } + case ffi::TypeIndex::kTVMFFIDataType: { + type_code[0] = kTVMDataType; + value[0].v_type = src.v_dtype; + break; + } + case ffi::TypeIndex::kTVMFFIDevice: { + type_code[0] = kDLDevice; + value[0].v_device = src.v_device; + break; + } + case ffi::TypeIndex::kTVMFFIDLTensorPtr: { + type_code[0] = kTVMDLTensorHandle; + value[0].v_handle = src.v_ptr; + break; + } + case ffi::TypeIndex::kTVMFFIRawStr: { + type_code[0] = kTVMStr; + value[0].v_str = src.v_c_str; + break; + } + case ffi::TypeIndex::kTVMFFIByteArrayPtr: { + type_code[0] = kTVMBytes; + value[0].v_handle = src.v_ptr; + break; + } + case ffi::TypeIndex::kTVMFFINDArray: { + type_code[0] = kTVMNDArrayHandle; + value[0].v_handle = ObjectHandleToTVMArrayHandle(reinterpret_cast(src.v_obj)); + break; + } + case ffi::TypeIndex::kTVMFFIModule: { + type_code[0] = kTVMModuleHandle; + value[0].v_handle = src.v_obj; + break; + } + case ffi::TypeIndex::kTVMFFIFunction: { + type_code[0] = kTVMPackedFuncHandle; + value[0].v_handle = src.v_obj; + break; + } + case ffi::TypeIndex::kTVMFFIObjectRValueRef: { + type_code[0] = kTVMObjectRValueRefArg; + value[0].v_handle = src.v_ptr; + break; + } + default: { + if (src.type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + type_code[0] = kTVMObjectHandle; + value[0].v_handle = src.v_obj; + break; + } + LOG(FATAL) << "Unsupported type index: " << src.type_index; + } + } +} -/*! \brief Using static function to output TypedPackedFunc signature */ -using FSig = std::string(); +/* + * \brief Move Any to legacy TVMValue and type_code + * \param src The Any to move + * \param value The TVMValue to store the result + * \param type_code The type code to store the result + */ +inline void MoveAnyToLegacyTVMValue(Any&& src, TVMValue* value, int* type_code) { + TVMFFIAny val = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src)); + // NOTE: conversion rule is the same as AnyViewToLegacyTVMArgValue + AnyViewToLegacyTVMArgValue(val, value, type_code); +} /*! - * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc" + * \brief Translate legacy TVMArgs to PackedArgs + * \param value The TVMValue array + * \param type_code The type code array + * \param num_args The number of arguments + * \param dst The destination AnyView array */ -template -class TypedPackedFunc; +inline void LegacyTVMArgsToPackedArgs(const TVMValue* value, const int* type_code, int num_args, + AnyView* dst) { + for (int i = 0; i < num_args; ++i) { + dst[i] = LegacyTVMArgValueToAnyView(value[i], type_code[i]); + } +} /*! - * \anchor TypedPackedFuncAnchor - * \brief A PackedFunc wrapper to provide typed function signature. - * It is backed by a PackedFunc internally. - * - * TypedPackedFunc enables compile time type checking. - * TypedPackedFunc works with the runtime system: - * - It can be passed as an argument of PackedFunc. - * - It can be assigned to TVMRetValue. - * - It can be directly converted to a type-erased PackedFunc. - * - * Developers should prefer TypedPackedFunc over PackedFunc in C++ code - * as it enables compile time checking. - * We can construct a TypedPackedFunc from a lambda function - * with the same signature. - * - * \code - * // user defined lambda function. - * auto addone = [](int x)->int { - * return x + 1; - * }; - * // We can directly convert - * // lambda function to TypedPackedFunc - * TypedPackedFunc ftyped(addone); - * // invoke the function. - * int y = ftyped(1); - * // Can be directly converted to PackedFunc - * PackedFunc packed = ftype; - * \endcode - * \tparam R The return value of the function. - * \tparam Args The argument signature of the function. + * \brief Translate legacy TVMArgs to PackedArgs + * \param args The AnyView array + * \param num_args The number of arguments + * \param value The TVMValue array + * \param type_code The type code array */ -template -class TypedPackedFunc { - public: - /*! \brief short hand for this function type */ - using TSelf = TypedPackedFunc; - /*! \brief default constructor */ - TypedPackedFunc() {} - /*! \brief constructor from null */ - TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*) - /*! - * \brief construct by wrap a PackedFunc - * - * Example usage: - * \code - * PackedFunc packed([](TVMArgs args, TVMRetValue *rv) { - * int x = args[0]; - * *rv = x + 1; - * }); - * // construct from packed function - * TypedPackedFunc ftyped(packed); - * // call the typed version. - * ICHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param packed The packed function - */ - inline TypedPackedFunc(PackedFunc packed); // NOLINT(*) - /*! - * \brief constructor from TVMRetValue - * \param value The TVMRetValue - */ - inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*) - /*! - * \brief constructor from TVMArgValue - * \param value The TVMArgValue - */ - inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*) - /*! - * \brief constructor from TVMMovableArgValue_ - * \param value The TVMMovableArgValue_ - */ - inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*) - /*! - * \brief construct from a lambda function with the same signature. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedPackedFunc ftyped(typed_lambda, "add_one"); - * // call the typed version. - * ICHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \param name the name of the lambda function. - * \tparam FLambda the type of the lambda function. - */ - template >::value>::type> - TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*) - this->AssignTypedLambda(typed_lambda, name); +inline void PackedArgsToLegacyTVMArgs(const AnyView* args, int num_args, TVMValue* value, + int* type_code) { + for (int i = 0; i < num_args; ++i) { + AnyViewToLegacyTVMArgValue(args[i].CopyToTVMFFIAny(), value + i, type_code + i); } - /*! - * \brief construct from a lambda function with the same signature. - * - * This version does not take a name. It is highly recommend you use the - * version that takes a name for the lambda. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedPackedFunc ftyped(typed_lambda); - * // call the typed version. - * ICHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - */ - template >::value>::type> - TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) - this->AssignTypedLambda(typed_lambda); - } - /*! - * \brief copy assignment operator from typed lambda - * - * Example usage: - * \code - * // construct from packed function - * TypedPackedFunc ftyped; - * ftyped = [](int x) { return x + 1; } - * // call the typed version. - * ICHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - * \returns reference to self. - */ - template >::value>::type> - TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) - this->AssignTypedLambda(typed_lambda); - return *this; - } - /*! - * \brief copy assignment operator from PackedFunc. - * \param packed The packed function. - * \returns reference to self. - */ - TSelf& operator=(PackedFunc packed) { - packed_ = packed; - return *this; - } - /*! - * \brief Invoke the operator. - * \param args The arguments - * \returns The return value. - */ - TVM_ALWAYS_INLINE R operator()(Args... args) const; - /*! - * \brief convert to PackedFunc - * \return the internal PackedFunc - */ - operator PackedFunc() const { return packed(); } - /*! - * \return reference the internal PackedFunc - */ - const PackedFunc& packed() const { return packed_; } - /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } +} - private: - friend class TVMRetValue; - /*! \brief The internal packed function */ - PackedFunc packed_; - /*! - * \brief Assign the packed field using a typed lambda function. - * - * \param flambda The lambda function. - * \param name The name associated with this lambda. - * \tparam FLambda The lambda function type. - * \note We capture the lambda when possible for maximum efficiency. - */ - template - inline void AssignTypedLambda(FLambda flambda, std::string name); - /*! - * \brief Assign the packed field using a typed lambda function. This variant is for functions - * without names. - * - * \param flambda The lambda function. - * \tparam FLambda The lambda function type. - * \note We capture the lambda when possible for maximum efficiency. - */ - template - inline void AssignTypedLambda(FLambda flambda); -}; +// redirect to ffi::PackedArgs +using TVMArgs = ffi::PackedArgs; +// redirect to ffi::AnyView and ffi::Any for ArgValue and RetValue +using TVMArgValue = ffi::AnyView; +using TVMRetValue = ffi::Any; -/*! \brief Arguments into TVM functions. */ -class TVMArgs { - public: - const TVMValue* values; - const int* type_codes; - int num_args; - /*! - * \brief constructor - * \param values The argument values - * \param type_codes The argument type codes - * \param num_args number of arguments. - */ - TVMArgs(const TVMValue* values, const int* type_codes, int num_args) - : values(values), type_codes(type_codes), num_args(num_args) {} - /*! \return size of the arguments */ - inline int size() const; - /*! - * \brief Get i-th argument - * \param i the index. - * \return the ith argument. - */ - inline TVMArgValue operator[](int i) const; - /*! - * \brief Get the i-th argument and do proper type checking with detailed error messages. - * \tparam T The expected type. - * \param i The index - * \return The corresponding argument value. - */ - template - inline T At(int i) const; -}; +// redirect to ffi::Function +using PackedFunc = ffi::Function; + +template +using TypedPackedFunc = ffi::TypedFunction; /*! * \brief Convert argument type code to string. * \param type_code The input type code. * \return The corresponding string repr. */ -inline const char* ArgTypeCode2Str(int type_code); - -inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) - -#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ - "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) - -// macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) - -/*! - * \brief Type traits for runtime type check during FFI conversion. - * \tparam T the type to be checked. - */ -template -struct ObjectTypeChecker { - /*! - * \brief Check if an object matches the template type and return the - * mismatched type if it exists. - * \param ptr The object to check the type of. - * \return An Optional containing the actual type of the pointer if it does not match the - * template type. If the Optional does not contain a value, then the types match. - */ - static Optional CheckAndGetMismatch(const Object* ptr) { - using ContainerType = typename T::ContainerType; - if (ptr == nullptr) { - if (T::_type_is_nullable) { - return NullOpt; - } else { - return String("nullptr"); - } - } - if (ptr->IsInstance()) { - return NullOpt; - } else { - return String(ptr->GetTypeKey()); - } - } - /*! - * \brief Check if an object matches the template type. - * \param ptr The object to check the type of. - * \return Whether or not the template type matches the objects type. - */ - static bool Check(const Object* ptr) { - using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return T::_type_is_nullable; - return ptr->IsInstance(); - } - static std::string TypeName() { - using ContainerType = typename T::ContainerType; - return ContainerType::_type_key; +inline const char* ArgTypeCode2Str(int type_code) { + switch (type_code) { + case kDLInt: + return "int"; + case kTVMArgBool: + return "bool"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case kTVMStr: + return "str"; + case kTVMBytes: + return "bytes"; + case kTVMOpaqueHandle: + return "handle"; + case kTVMNullptr: + return "NULL"; + case kTVMDLTensorHandle: + return "ArrayHandle"; + case kTVMDataType: + return "DLDataType"; + case kDLDevice: + return "DLDevice"; + case kTVMPackedFuncHandle: + return "FunctionHandle"; + case kTVMModuleHandle: + return "ModuleHandle"; + case kTVMNDArrayHandle: + return "NDArrayContainer"; + case kTVMObjectHandle: + return "Object"; + case kTVMObjectRValueRefArg: + return "ObjectRValueRefArg"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); } -}; - -// Additional overloads for PackedFunc checking. -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - if (ptr == nullptr) { - return NullOpt; - } - if (!ptr->IsInstance()) { - return String(ptr->GetTypeKey()); - } + throw; +} - if constexpr (std::is_same_v) { - return NullOpt; - } +namespace details { - const ArrayNode* n = static_cast(ptr); - for (size_t i = 0; i < n->size(); i++) { - const ObjectRef& p = (*n)[i]; - Optional check_subtype = ObjectTypeChecker::CheckAndGetMismatch(p.get()); - if (check_subtype.defined()) { - return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]"); - } - } - return NullOpt; - } - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - if constexpr (std::is_same_v) return true; +template +struct ModuleVTableEntryHelper {}; - const ArrayNode* n = static_cast(ptr); - for (const ObjectRef& p : *n) { - if (!ObjectTypeChecker::Check(p.get())) { - return false; - } - } - return true; +template +struct ModuleVTableEntryHelper { + using MemFnType = R (T::*)(Args...) const; + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); } - static std::string TypeName() { return "Array[" + ObjectTypeChecker::TypeName() + "]"; } }; -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - if (ptr == nullptr) return NullOpt; - if (!ptr->IsInstance()) return String(ptr->GetTypeKey()); - - if constexpr (std::is_same_v && std::is_same_v) { - return NullOpt; - } - - const MapNode* n = static_cast(ptr); - for (const auto& kv : *n) { - Optional key_type = NullOpt; - if constexpr (!std::is_same_v) { - key_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); - } - Optional value_type = NullOpt; - if constexpr (!std::is_same_v) { - value_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); - } - if (key_type.defined() || value_type.defined()) { - std::string key_name = - key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker::TypeName(); - std::string value_name = value_type.defined() ? std::string(value_type.value()) - : ObjectTypeChecker::TypeName(); - return String("Map[" + key_name + ", " + value_name + "]"); - } - } - return NullOpt; - } - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - - if constexpr (std::is_same_v && std::is_same_v) { - return true; - } - - const MapNode* n = static_cast(ptr); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!ObjectTypeChecker::Check(kv.first.get())) return false; - } - if constexpr (!std::is_same_v) { - if (!ObjectTypeChecker::Check(kv.second.get())) return false; - } - } - return true; - } - static std::string TypeName() { - return "Map[" + ObjectTypeChecker::TypeName() + ", " + ObjectTypeChecker::TypeName() + - ']'; +template +struct ModuleVTableEntryHelper { + using MemFnType = R (T::*)(Args...); + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); } }; -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - return ObjectTypeChecker::CheckAndGetMismatch(ptr); +template +struct ModuleVTableEntryHelper { + using MemFnType = void (T::*)(Args...) const; + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); } - static bool Check(const Object* ptr) { return ObjectTypeChecker::Check(ptr); } - static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } - static std::string VariantNames() { return ObjectTypeChecker::TypeName(); } }; -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - auto try_first = ObjectTypeChecker::CheckAndGetMismatch(ptr); - if (!try_first.defined()) { - return try_first; - } - - return ObjectTypeChecker>::CheckAndGetMismatch(ptr); - } - static bool Check(const Object* ptr) { - return ObjectTypeChecker::Check(ptr) || - ObjectTypeChecker>::Check(ptr); - } - static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } - static std::string VariantNames() { - return ObjectTypeChecker::TypeName() + ", " + - ObjectTypeChecker>::VariantNames(); +template +struct ModuleVTableEntryHelper { + using MemFnType = void (T::*)(Args...); + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); } }; +} // namespace details -/*! - * \brief Internal base class to - * handle conversion to POD values. - */ -class TVMPODValue_ { - public: - operator void*() const { - if (type_code_ == kTVMNullptr) return nullptr; - if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; - TVM_CHECK_TYPE_CODE(type_code_, kTVMOpaqueHandle); - return value_.v_handle; - } - operator DLTensor*() const { - if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { - return static_cast(value_.v_handle); - } else { - if (type_code_ == kTVMNullptr) return nullptr; - LOG(FATAL) << "Expected " - << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_); - return nullptr; - } - } - operator NDArray() const { - if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - return NDArray(NDArray::FFIDataFromHandle(static_cast(value_.v_handle))); - } - operator Module() const { - if (type_code_ == kTVMNullptr) { - return Module(ObjectPtr(nullptr)); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); - return Module(ObjectPtr(static_cast(value_.v_handle))); - } - operator PackedFunc() const { - if (type_code_ == kTVMNullptr) { - return PackedFunc(ObjectPtr(nullptr)); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); - return PackedFunc(ObjectPtr(static_cast(value_.v_handle))); - } - operator Device() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLDevice); - return value_.v_device; - } - int type_code() const { return type_code_; } - /*! - * \brief return handle as specific pointer type. - * \tparam T the data type. - * \return The pointer type. - */ - template - T* ptr() const { - return static_cast(value_.v_handle); - } - - std::optional TryAsBool() const { - // Helper function to reduce duplication in the variable integer - // conversions. This is publicly exposed, as it can be useful in - // specializations of PackedFuncValueConverter. - if (type_code_ == kTVMArgBool) { - return static_cast(value_.v_int64); - } else { - return std::nullopt; - } +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* type_key() const final { return TypeKey; } \ + PackedFunc GetFunction(const String& _name, const ObjectPtr& _self) override { \ + using SelfPtr = std::remove_cv_t; +#define TVM_MODULE_VTABLE_END() \ + return PackedFunc(nullptr); \ } - - std::optional TryAsInt() const { - // Helper function to reduce duplication in the variable integer - // conversions. This is publicly exposed, as it can be useful in - // specializations of PackedFuncValueConverter. - if (type_code_ == kDLInt) { - return value_.v_int64; - } else { - return std::nullopt; - } +#define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ + { \ + auto f = (MemFunc); \ + return (this->*f)(_name); \ + } \ + } // NOLINT(*) +#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \ + if (_name == Name) { \ + return ffi::Function::FromPacked([_self](ffi::PackedArgs args, Any* rv) -> void { \ + using Helper = ::tvm::runtime::details::ModuleVTableEntryHelper; \ + SelfPtr self = static_cast(_self.get()); \ + Helper::Call(rv, self, MemFunc, args); \ + }); \ } - - std::optional TryAsFloat() const { - // Helper function to reduce duplication in the variable integer - // conversions. This is publicly exposed, as it can be useful in - // specializations of PackedFuncValueConverter. - if (type_code_ == kDLFloat) { - return value_.v_float64; - } else { - return std::nullopt; - } +#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \ + if (_name == Name) { \ + return PackedFunc([_self](ffi::PackedArgs args, Any* rv) -> void { \ + (static_cast(_self.get())->*(MemFunc))(args, rv); \ + }); \ } - protected: - friend class TVMArgsSetter; - friend class TVMRetValue; - friend class TVMMovableArgValue_; - TVMPODValue_() : type_code_(kTVMNullptr) {} - TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} - - /*! \brief The value */ - TVMValue value_; - /*! \brief the type code */ - int type_code_; -}; - -/*! \brief A utility class that adds methods useful for each POD type - * - * These cannot be provided in the base PODValue_ class, because - * TVMArgValue and TVMRetValue have different semantics for kTVMStr - * and kTVMBytes. - * - * kTVMStr: +/*! + * \brief Export typed function as a PackedFunc + * that can be loaded by LibraryModule. * - * For `TVMArgValue`, the active variant is `v_str`, a `const - * char*`. For `TVMRetValue`, the active variant is `v_handle`, - * and should be cast from `void*` to `std::string*`. + * \param ExportName The symbol name to be exported. + * \param Function The typed function. + * \note ExportName and Function must be different, + * see code examples below. * - * kTVMBytes: + * \sa TypedPackedFunc * - * The active variant is `v_handle`, a `void*`. For - * `TVMArgValue`, should be cast to `TVMByteArray*`. For - * `TVMRetValue`, should be cast to `std::string*`. + * \code * - * When converting into an `ObjectRef`, a string may be used to build - * a `tvm::runtime::String`. Because TVMArgValue and TVMRetValue use - * different representations for strings, any utility funciton which - * might attempt a conversion to an `ObjectRef` must be performed - * within a context that is aware of the derived class. - */ -template -class TVMPODValue_CRTP_ : public TVMPODValue_ { - public: - using TVMPODValue_::TVMPODValue_; - - // ObjectRef handling - template ::value>::type> - inline bool IsObjectRef() const; - template - inline TObjectRef AsObjectRef() const; - - operator double() const { - // Allow automatic conversion from int to float - // This avoids errors when user pass in int from - // the frontend while the API expects a float. - if (auto opt = TryAsFloat()) { - return opt.value(); - } else if (auto opt = TryAsInt()) { - return opt.value(); - } else if (auto opt = TryAsBool()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); - } - } - operator int64_t() const { - if (auto opt = TryAsInt()) { - return opt.value(); - } else if (auto opt = TryAsBool()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); - } - } - operator uint64_t() const { return operator int64_t(); } - operator int() const { - int64_t value = operator int64_t(); - ICHECK_LE(value, std::numeric_limits::max()); - ICHECK_GE(value, std::numeric_limits::min()); - return value; - } - operator bool() const { - if (auto opt = TryAsBool()) { - return opt.value(); - } else if (auto opt = TryAsInt()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); - } - } -}; - -/*! - * \brief A single argument value to PackedFunc. - * Containing both type_code and TVMValue + * int AddOne_(int x) { + * return x + 1; + * } * - * Provides utilities to do type cast into other types. - */ -class TVMArgValue : public TVMPODValue_CRTP_ { - public: - /*! \brief default constructor */ - TVMArgValue() {} - /*! - * \brief constructor - * \param value of the function - * \param type_code The type code. - */ - TVMArgValue(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} - // reuse converter from parent - using TVMPODValue_CRTP_::operator double; - using TVMPODValue_CRTP_::operator int64_t; - using TVMPODValue_CRTP_::operator uint64_t; - using TVMPODValue_CRTP_::operator int; - using TVMPODValue_CRTP_::operator bool; - using TVMPODValue_::operator void*; - using TVMPODValue_::operator DLTensor*; - using TVMPODValue_::operator NDArray; - using TVMPODValue_::operator Device; - using TVMPODValue_::operator Module; - using TVMPODValue_::operator PackedFunc; - using TVMPODValue_CRTP_::AsObjectRef; - using TVMPODValue_CRTP_::IsObjectRef; - - // conversion operator. - operator std::string() const { - if (type_code_ == kTVMDataType) { - return DLDataType2String(operator DLDataType()); - } else if (type_code_ == kTVMBytes) { - TVMByteArray* arr = static_cast(value_.v_handle); - return std::string(arr->data, arr->size); - } else if (type_code_ == kTVMStr) { - return std::string(value_.v_str); - } else { - return AsObjectRef().operator std::string(); - } - } - template - operator TypedPackedFunc() const { - return TypedPackedFunc(operator PackedFunc()); - } - const TVMValue& value() const { return value_; } - - template ::value>::type> - inline operator T() const; - inline operator DLDataType() const; - inline operator DataType() const; -}; - -/*! - * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument. + * // Expose the function as "AddOne" + * TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); * - * We can only construct a movable argument once from a single argument position. - * If the argument is passed as RValue reference, the result will be moved. - * We should only construct a MovableArg from an argument once, - * as the result will can moved. + * // Expose the function as "SubOne" + * TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { + * return x - 1; + * }); * - * \note For internal development purpose only. - */ -class TVMMovableArgValue_ : public TVMPODValue_CRTP_ { - public: - TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} - // reuse converter from parent - using TVMPODValue_CRTP_::operator double; - using TVMPODValue_CRTP_::operator int64_t; - using TVMPODValue_CRTP_::operator uint64_t; - using TVMPODValue_CRTP_::operator int; - using TVMPODValue_CRTP_::operator bool; - using TVMPODValue_::operator void*; - using TVMPODValue_::operator DLTensor*; - using TVMPODValue_::operator NDArray; - using TVMPODValue_::operator Device; - using TVMPODValue_::operator Module; - using TVMPODValue_::operator PackedFunc; - // reuse conversion rule from ArgValue. - operator std::string() const { return AsArgValue().operator std::string(); } - template - operator TypedPackedFunc() const { - return TypedPackedFunc(operator PackedFunc()); - } - operator DLDataType() const { return AsArgValue().operator DLDataType(); } - operator DataType() const { return AsArgValue().operator DataType(); } - operator TVMArgValue() const { return AsArgValue(); } - /*! - * \brief Helper converter function. - * Try to move out an argument if possible, - * fall back to normal argument conversion rule otherwise. - */ - template ::value>::type> - inline operator T() const; - - private: - /*! \return The arg value repr of the value. */ - TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); } -}; - -/*! - * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with - * additional context information (function name and argument index) for better error reporting. + * // The following code will cause compilation error. + * // Because the same Function and ExportName + * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); * - * \sa MovableArgValue_ - * \note For internal development purpose only. - */ -class TVMMovableArgValueWithContext_ { - public: - /*! - * \brief move constructor from another return value. - * \param value The other return value. - * \param type_code The code associated with the type of the value. - * \param arg_index In a function call, this argument is at index arg_index (0-indexed). - * \param optional_name Name of the function being called. Can be nullptr if the function is not. - * \param f_sig Pointer to static function outputting signature of the function being called. - * named. - */ - TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, - const std::string* optional_name, FSig* f_sig) - : value_(value, type_code), - arg_index_(arg_index), - optional_name_(optional_name), - f_sig_(f_sig) {} - - template - operator T() const { - try { - return value_; // implicit conversion happens here - } catch (dmlc::Error& e) { - LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "" : *optional_name_) - << (f_sig_ == nullptr ? "" : (*f_sig_)()) << ": error while converting argument " - << arg_index_ << ": " << e.what(); - throw; // never reached, LOG(FATAL) throws, but this silences a warning. - } - } - - private: - TVMMovableArgValue_ value_; - int arg_index_; - const std::string* optional_name_; - FSig* f_sig_; -}; - -/*! - * \brief Return Value container, - * Unlike TVMArgValue, which only holds reference and do not delete - * the underlying container during destruction. + * // The following code is OK, assuming the macro + * // is in a different namespace from xyz + * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); * - * TVMRetValue holds value and will manage the underlying containers - * when it stores a complicated data type. + * \endcode */ -class TVMRetValue : public TVMPODValue_CRTP_ { - public: - /*! \brief default constructor */ - TVMRetValue() {} - /*! - * \brief move constructor from another return value. - * \param other The other return value. - */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_CRTP_(other.value_, other.type_code_) { - other.value_.v_handle = nullptr; - other.type_code_ = kTVMNullptr; - } - /*! \brief destructor */ - ~TVMRetValue() { this->Clear(); } - // reuse converter from parent - using TVMPODValue_CRTP_::operator double; - using TVMPODValue_CRTP_::operator int64_t; - using TVMPODValue_CRTP_::operator uint64_t; - using TVMPODValue_CRTP_::operator int; - using TVMPODValue_CRTP_::operator bool; - using TVMPODValue_::operator void*; - using TVMPODValue_::operator DLTensor*; - using TVMPODValue_::operator Device; - using TVMPODValue_::operator NDArray; - using TVMPODValue_::operator Module; - using TVMPODValue_::operator PackedFunc; - using TVMPODValue_CRTP_::AsObjectRef; - using TVMPODValue_CRTP_::IsObjectRef; - - TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); } - // conversion operators - operator std::string() const { - if (type_code_ == kTVMDataType) { - return DLDataType2String(operator DLDataType()); - } else if (type_code_ == kTVMBytes) { - return *ptr(); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); - return *ptr(); - } - operator DLDataType() const { - if (type_code_ == kTVMStr) { - return String2DLDataType(operator std::string()); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); - return value_.v_type; - } - operator DataType() const { return DataType(operator DLDataType()); } - template - operator TypedPackedFunc() const { - return TypedPackedFunc(operator PackedFunc()); - } - // Assign operators - TVMRetValue& operator=(TVMRetValue&& other) { - this->Clear(); - value_ = other.value_; - type_code_ = other.type_code_; - other.type_code_ = kTVMNullptr; - return *this; - } - TVMRetValue& operator=(double value) { - this->SwitchToPOD(kDLFloat); - value_.v_float64 = value; - return *this; - } - TVMRetValue& operator=(std::nullptr_t value) { - this->SwitchToPOD(kTVMNullptr); - value_.v_handle = value; - return *this; - } - TVMRetValue& operator=(void* value) { - this->SwitchToPOD(kTVMOpaqueHandle); - value_.v_handle = value; - return *this; - } - TVMRetValue& operator=(int64_t value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; - return *this; - } - TVMRetValue& operator=(int value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; - return *this; - } - TVMRetValue& operator=(DLDevice value) { - this->SwitchToPOD(kDLDevice); - value_.v_device = value; - return *this; - } - TVMRetValue& operator=(DLDataType t) { - this->SwitchToPOD(kTVMDataType); - value_.v_type = t; - return *this; - } - TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } - TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kTVMArgBool); - value_.v_int64 = value; - return *this; - } - TVMRetValue& operator=(std::string value) { - this->SwitchToClass(kTVMStr, value); - return *this; - } - TVMRetValue& operator=(TVMByteArray value) { - this->SwitchToClass(kTVMBytes, std::string(value.data, value.size)); - return *this; - } - TVMRetValue& operator=(NDArray other) { - if (other.data_ != nullptr) { - this->Clear(); - type_code_ = kTVMNDArrayHandle; - value_.v_handle = NDArray::FFIGetHandle(other); - ObjectRef::FFIClearAfterMove(&other); - } else { - SwitchToPOD(kTVMNullptr); - value_.v_handle = nullptr; - } - return *this; - } - TVMRetValue& operator=(Module m) { - SwitchToObject(kTVMModuleHandle, std::move(m.data_)); - return *this; - } - TVMRetValue& operator=(PackedFunc f) { - this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_)); - return *this; - } - template - TVMRetValue& operator=(const TypedPackedFunc& f) { - return operator=(f.packed()); - } - TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 - this->Assign(other); - return *this; - } - TVMRetValue& operator=(const TVMArgValue& other) { - this->Assign(other); - return *this; - } - TVMRetValue& operator=(TVMMovableArgValue_&& other) { - this->Assign(other); - return *this; - } - /*! - * \brief Move the value back to front-end via C API. - * This marks the current container as null. - * The managed resources are moved to the front-end. - * The front end should take charge in managing them. - * - * \param ret_value The return value. - * \param ret_type_code The return type code. - */ - void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { - // cannot move str; need specially handle. - ICHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes); - *ret_value = value_; - *ret_type_code = type_code_; - type_code_ = kTVMNullptr; - } - /*! - * \brief Construct a new TVMRetValue by - * moving from return value stored via C API. - * \param value the value. - * \param type_code The type code. - * \return The created TVMRetValue. - */ - static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { - // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || - type_code == kTVMArgBool); - TVMRetValue ret; - ret.value_ = value; - ret.type_code_ = type_code; - return ret; - } - /*! \return The value field, if the data is POD */ - const TVMValue& value() const { - ICHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle && - type_code_ != kTVMModuleHandle && type_code_ != kTVMStr) - << "TVMRetValue.value can only be used for POD data"; - return value_; - } - // ObjectRef handling - template >> - inline TVMRetValue& operator=(TObjectRef other); - template >> - inline operator T() const; - - private: - template - void Assign(const T& other) { - switch (other.type_code()) { - case kTVMStr: { - SwitchToClass(kTVMStr, other); - break; - } - case kTVMBytes: { - SwitchToClass(kTVMBytes, other); - break; - } - case kTVMPackedFuncHandle: { - *this = other.operator PackedFunc(); - break; - } - case kTVMModuleHandle: { - *this = other.operator Module(); - break; - } - case kTVMNDArrayHandle: { - *this = other.operator NDArray(); - break; - } - case kTVMObjectHandle: { - // We already known it is not NDArray/Module, but - // operator=(ObjectRef) also handles conversions from wrappers - // around primitive types. For NDArray/Module, the duplicate - // checks are removed with if constexpr. - operator=(other.operator ObjectRef()); - break; - } - case kTVMObjectRValueRefArg: { - operator=(other.operator ObjectRef()); - break; - } - default: { - SwitchToPOD(other.type_code()); - value_ = other.value_; - break; - } - } - } - // get the internal container. - void SwitchToPOD(int type_code) { - if (type_code_ != type_code) { - this->Clear(); - type_code_ = type_code; - } - } - template - void SwitchToClass(int type_code, T v) { - if (type_code_ != type_code) { - this->Clear(); - type_code_ = type_code; - value_.v_handle = new T(v); - } else { - *static_cast(value_.v_handle) = v; - } - } - void SwitchToObject(int type_code, ObjectPtr other) { - if (other.data_ != nullptr) { - this->Clear(); - type_code_ = type_code; - // move the handle out - value_.v_handle = other.data_; - other.data_ = nullptr; - } else { - SwitchToPOD(kTVMNullptr); - value_.v_handle = nullptr; - } - } - void Clear() { - if (type_code_ == kTVMNullptr) return; - switch (type_code_) { - case kTVMStr: - case kTVMBytes: - delete ptr(); - break; - case kTVMPackedFuncHandle: - static_cast(value_.v_handle)->DecRef(); - break; - case kTVMNDArrayHandle: { - NDArray::FFIDecRef(static_cast(value_.v_handle)); - break; - } - case kTVMModuleHandle: { - static_cast(value_.v_handle)->DecRef(); - break; - } - case kTVMObjectHandle: { - static_cast(value_.v_handle)->DecRef(); - break; - } - } - type_code_ = kTVMNullptr; - } -}; - -/*! - * \brief Type trait to specify special value conversion rules from - * TVMArgValue and TVMRetValue. - * - * The trait can be specialized to add type specific conversion logic - * from the TVMArgvalue and TVMRetValue. - * - * \tparam TObjectRef the specific ObjectRefType. - */ -template -struct PackedFuncValueConverter { - /*! - * \brief Convert a TObjectRef from an argument value. - * \param val The argument value. - * \return the converted result. - */ - static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef(); } - /*! - * \brief Convert a TObjectRef from a return value. - * \param val The argument value. - * \return the converted result. - */ - static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef(); } -}; - -/*! - * \brief Export a function with the PackedFunc signature - * as a PackedFunc that can be loaded by LibraryModule. - * - * \param ExportName The symbol name to be exported. - * \param Function The function with PackedFunc signature. - * \sa PackedFunc - * - * \code - * - * void AddOne_(TVMArgs args, TVMRetValue* rv) { - * int value = args[0]; - * *rv = value + 1; - * } - * // Expose the function as "AddOne" - * TVM_DLL_EXPORT_PACKED_FUNC(AddOne, AddOne_); - * - * \endcode - */ -#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ - int* out_type_code, void* resource_handle); \ - int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ - int* out_type_code, void* resource_handle) { \ - try { \ - ::tvm::runtime::TVMRetValue rv; \ - Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::exception& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ - } - -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* type_key() const final { return TypeKey; } \ - PackedFunc GetFunction(const String& _name, const ObjectPtr& _self) override { \ - using SelfPtr = std::remove_cv_t; -#define TVM_MODULE_VTABLE_END() \ - return PackedFunc(nullptr); \ - } -#define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ - { \ - auto f = (MemFunc); \ - return (this->*f)(_name); \ - } \ - } // NOLINT(*) -#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \ - if (_name == Name) { \ - return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \ - using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper; \ - SelfPtr self = static_cast(_self.get()); \ - CHECK_EQ(args.size(), Helper::LenArgs) \ - << "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \ - << " arguments, but got " << args.size(); \ - Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \ - }); \ - } -#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \ - if (_name == Name) { \ - return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \ - (static_cast(_self.get())->*(MemFunc))(args, rv); \ - }); \ - } - -/*! - * \brief Export typed function as a PackedFunc - * that can be loaded by LibraryModule. - * - * \param ExportName The symbol name to be exported. - * \param Function The typed function. - * \note ExportName and Function must be different, - * see code examples below. - * - * \sa TypedPackedFunc - * - * \code - * - * int AddOne_(int x) { - * return x + 1; - * } - * - * // Expose the function as "AddOne" - * TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); - * - * // Expose the function as "SubOne" - * TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { - * return x - 1; - * }); - * - * // The following code will cause compilation error. - * // Because the same Function and ExportName - * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); - * - * // The following code is OK, assuming the macro - * // is in a different namespace from xyz - * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); - * - * \endcode - */ -#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ - int* out_type_code, void* resource_handle) { \ - try { \ - auto f = Function; \ - using FType = ::tvm::runtime::detail::function_signature::FType; \ - ::tvm::runtime::TVMRetValue rv; \ - ::tvm::runtime::detail::unpack_call_by_signature::run( \ - f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::exception& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ - } - -inline TVMArgValue TVMArgs::operator[](int i) const { - ICHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" - << " but request arg[" << i << "]."; - return TVMArgValue(values[i], type_codes[i]); -} - -inline int TVMArgs::size() const { return num_args; } - -template -void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, TVMArgs args, - TVMRetValue* rv) { - (static_cast(obj))->callable_(args, rv); -} - -TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { - (*f_call_packed_)(this, args, rv); -} - -TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { - (static_cast(data_.get()))->CallPacked(args, rv); -} - -// internal namespace -inline const char* ArgTypeCode2Str(int type_code) { - switch (type_code) { - case kDLInt: - return "int"; - case kTVMArgBool: - return "bool"; - case kDLUInt: - return "uint"; - case kDLFloat: - return "float"; - case kTVMStr: - return "str"; - case kTVMBytes: - return "bytes"; - case kTVMOpaqueHandle: - return "handle"; - case kTVMNullptr: - return "NULL"; - case kTVMDLTensorHandle: - return "ArrayHandle"; - case kTVMDataType: - return "DLDataType"; - case kDLDevice: - return "DLDevice"; - case kTVMPackedFuncHandle: - return "FunctionHandle"; - case kTVMModuleHandle: - return "ModuleHandle"; - case kTVMNDArrayHandle: - return "NDArrayContainer"; - case kTVMObjectHandle: - return "Object"; - case kTVMObjectRValueRefArg: - return "ObjectRValueRefArg"; - default: - LOG(FATAL) << "unknown type_code=" << static_cast(type_code); - } - throw; -} - -/*! - * \brief The name of DLDeviceType. - * \param type The device type. - * \return the device name. - */ -inline const char* DLDeviceType2Str(int type) { - switch (type) { - case kDLCPU: - return "cpu"; - case kDLCUDA: - return "cuda"; - case kDLCUDAHost: - return "cuda_host"; - case kDLCUDAManaged: - return "cuda_managed"; - case kDLOpenCL: - return "opencl"; - case kDLVulkan: - return "vulkan"; - case kDLMetal: - return "metal"; - case kDLVPI: - return "vpi"; - case kDLROCM: - return "rocm"; - case kDLROCMHost: - return "rocm_host"; - case kDLExtDev: - return "ext_dev"; - case kDLOneAPI: - return "oneapi"; - case kDLWebGPU: - return "webgpu"; - case kDLHexagon: - return "hexagon"; - default: - LOG(FATAL) << "unknown type = " << type; - } - throw; -} - -namespace detail { - -template -struct for_each_dispatcher { - template - static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) - f(I, std::forward(value)); - for_each_dispatcher::run(f, std::forward(args)...); - } -}; - -template -struct for_each_dispatcher { - static void run(const F& f) {} // NOLINT(*) -}; - -template -inline void for_each(const F& f, Args&&... args) { // NOLINT(*) - for_each_dispatcher::run(f, std::forward(args)...); -} - -template -struct ModuleVTableEntryHelper {}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = R (T::*)(Args...) const; - using IndexSeq = std::index_sequence_for; - static constexpr const std::size_t LenArgs = sizeof...(Args); - - template - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, - std::index_sequence) { - *rv = (self->*f)(args[Is]...); - } -}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = R (T::*)(Args...); - using IndexSeq = std::index_sequence_for; - static constexpr const std::size_t LenArgs = sizeof...(Args); - - template - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, - std::index_sequence) { - *rv = (self->*f)(args[Is]...); - } -}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = void (T::*)(Args...) const; - using IndexSeq = std::index_sequence_for; - static constexpr const std::size_t LenArgs = sizeof...(Args); - - template - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, - std::index_sequence) { - (self->*f)(args[Is]...); - } -}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = void (T::*)(Args...); - using IndexSeq = std::index_sequence_for; - static constexpr const std::size_t LenArgs = sizeof...(Args); - - template - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, - std::index_sequence) { - (self->*f)(args[Is]...); - } -}; - -namespace parameter_pack { - -template -struct EnumeratedParamPack { - struct InvokeWithoutArg { - template